Files
sglang/python/sglang/srt/two_batch_overlap.py
Chaitanya Sri Krishna Lolla 323bc2f51a Enable TBO on ROCm (#8329)
2025-08-09 01:59:55 -07:00

989 lines
34 KiB
Python

from __future__ import annotations
import copy
import dataclasses
import logging
from dataclasses import replace
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.communicator import (
CommunicateContext,
CommunicateSummableTensorPairFn,
ScatterMode,
)
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
compute_position,
)
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
# Result of `is_hip()` evaluated once at import time (presumably true on ROCm/HIP
# builds -- confirm against `sglang.srt.utils`). `_model_forward_tbo` uses it to
# skip the DeepGEMM SM-count configuration.
_is_hip = is_hip()
# Boolean read from the `SGLANG_TBO_DEBUG` environment variable; when true,
# `TboForwardBatchPreparer.prepare_raw` logs the computed split information.
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# -------------------------------- Compute Basic Info ---------------------------------------
def get_token_num_per_seq(
    forward_mode: ForwardMode,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
):
    """Return how many tokens every sequence contributes under `forward_mode`.

    Args:
        forward_mode: Mode object queried via `is_target_verify()`,
            `is_decode()` and `is_idle()`.
        spec_info: Speculative-decoding info; only read (for
            `draft_token_num`) in target-verify mode.

    Returns:
        `spec_info.draft_token_num` for target-verify, 1 for decode, 0 for
        idle, and None for every other mode (the value is not meaningful for
        extend).
    """
    if forward_mode.is_target_verify():
        return spec_info.draft_token_num
    if forward_mode.is_decode():
        return 1
    if forward_mode.is_idle():
        return 0
    # For extend, we should not use `token_num_per_seq`.
    return None
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
def compute_split_seq_index(
    forward_mode: "ForwardMode",
    num_tokens: int,
    extend_lens: Optional[Sequence[int]],
    token_num_per_seq: Optional[int],
) -> Optional[int]:
    """Choose the sequence index at which a batch is cut into two sub-batches.

    Args:
        forward_mode: Mode of the batch being split.
        num_tokens: Total number of tokens in the batch.
        extend_lens: Per-sequence extend lengths; required for EXTEND.
        token_num_per_seq: Tokens per sequence; required for decode and
            target-verify.

    Returns:
        The index of the first sequence that belongs to the second sub-batch.

    Raises:
        NotImplementedError: If the forward mode is none of the handled ones.
    """
    if forward_mode == ForwardMode.EXTEND:
        assert extend_lens is not None
        return _split_extend_seqs(extend_lens)

    if forward_mode.is_target_verify() or forward_mode.is_decode():
        assert token_num_per_seq is not None
        # Every sequence has the same token count, so halve the sequence count.
        num_seqs = num_tokens // token_num_per_seq
        return num_seqs // 2

    if forward_mode.is_idle():
        assert num_tokens == 0
        return 0

    raise NotImplementedError()
def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
    """Tell whether a whole-sequence split is too unbalanced to be used as is.

    The balanced split from `_split_array_by_balanced_sum` is evaluated; when
    the left part holds less than `threshold` or more than `1 - threshold` of
    all tokens, the two-chunk split is considered enabled. The threshold is
    read from `global_server_args_dict["tbo_token_distribution_threshold"]`.

    Returns:
        False when `extend_lens` is None, otherwise the result of the check
        described above.
    """
    if extend_lens is None:
        return False

    balanced_index = _split_array_by_balanced_sum(extend_lens)
    left_tokens = sum(extend_lens[:balanced_index])
    total_tokens = sum(extend_lens)

    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
    assert threshold <= 0.5, f"{threshold=}"

    lower_bound = total_tokens * threshold
    upper_bound = total_tokens * (1 - threshold)
    return not (lower_bound <= left_tokens <= upper_bound)
def _split_extend_seqs(arr: Sequence[int]) -> int:
    """Pick the split index for an extend batch.

    Uses the cumulative "less than half" rule when the two-chunk split is
    enabled for `arr`, and the balanced-sum rule otherwise.
    """
    splitter = (
        _split_array_by_cum_less_than_half
        if _is_two_chunk_split_enabled(arr)
        else _split_array_by_balanced_sum
    )
    return splitter(arr)
def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
    """Return the index of the element at which the running sum passes half.

    "Half" is `sum(arr) // 2`; the first index whose inclusive prefix sum is
    strictly greater than that value is returned.

    Returns:
        That index, or 0 when no prefix exceeds half (e.g. an empty or
        all-zero `arr`).
    """
    half_total = sum(arr) // 2
    running_total = 0
    for index, length in enumerate(arr):
        running_total += length
        if running_total > half_total:
            return index
    # No prefix ever exceeded half of the total.
    return 0
def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
overall_sum = sum(arr)
left_sum = 0
min_diff = float("inf")
best_index = 0
for i in range(1, len(arr)):
left_sum += arr[i - 1]
right_sum = overall_sum - left_sum
diff = abs(left_sum - right_sum)
if diff <= min_diff:
min_diff = diff
best_index = i
else:
break
return best_index
def _update_device_and_sum_field_from_cpu_field(
    batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None
):
    """Refresh a device-side field (and optionally a sum field) from its CPU twin.

    Args:
        batch: Object whose attributes are read and written.
        cpu_field: Name of the attribute holding the CPU value (tensor or list).
        device_field: Name of the attribute to overwrite with the value moved
            to `global_server_args_dict["device"]`.
        sum_field: Optional name of an attribute that receives the sum of the
            CPU value.

    Nothing is changed when either attribute is missing/None or when the CPU
    value is neither a `torch.Tensor` nor a list.
    """
    cpu_value = getattr(batch, cpu_field, None)
    previous_device_value = getattr(batch, device_field, None)
    if cpu_value is None or previous_device_value is None:
        return
    if not isinstance(cpu_value, (torch.Tensor, list)):
        return

    is_tensor = isinstance(cpu_value, torch.Tensor)
    if is_tensor:
        host_tensor = cpu_value
    else:
        # Lists are materialised with the dtype of the value being replaced.
        host_tensor = torch.tensor(cpu_value, dtype=previous_device_value.dtype)
    setattr(
        batch,
        device_field,
        host_tensor.to(device=global_server_args_dict["device"], non_blocking=True),
    )

    if sum_field is None:
        return
    total = cpu_value.sum().item() if is_tensor else sum(cpu_value)
    setattr(batch, sum_field, total)
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
    """Return the custom-mask offset at which sequence `seq_index` starts.

    Each preceding sequence `i` contributes
    `(seq_lens_cpu[i] + draft_token_num) * draft_token_num` entries. At most
    `spec_info.seq_lens_cpu.shape[0]` sequences are counted.

    Returns:
        0 when `seq_index` is 0, otherwise the accumulated offset.
    """
    if seq_index == 0:
        return 0

    seq_lens = spec_info.seq_lens_cpu
    num_counted = min(seq_index, seq_lens.shape[0])
    return sum(
        (seq_lens[i] + spec_info.draft_token_num) * spec_info.draft_token_num
        for i in range(num_counted)
    )
def split_spec_info(
    spec_info: Optional[EagleVerifyInput],
    start_seq_index: int,
    end_seq_index: int,
    start_token_index: int,
    end_token_index: int,
):
    """Return a copy of `spec_info` restricted to one sub-batch.

    Token-indexed fields (`draft_token`, `positions`) are sliced with the token
    range; sequence-indexed fields (`retrive_*`, `seq_lens_cpu`) with the
    sequence range. `custom_mask` is sliced using `_compute_mask_offset` when
    both it and `draft_token` are present and the resulting range is
    non-empty; otherwise the original mask is kept. `seq_lens_sum` is
    recomputed from the sliced `seq_lens_cpu`.

    Returns:
        None when `spec_info` is None, otherwise the result of
        `dataclasses.replace` with the sliced fields.
    """
    if spec_info is None:
        return None

    token_range = slice(start_token_index, end_token_index)
    seq_range = slice(start_seq_index, end_seq_index)

    def _slice_or_none(value, index_range):
        # Absent fields stay absent in the sub-batch.
        return None if value is None else value[index_range]

    draft_token = _slice_or_none(spec_info.draft_token, token_range)

    # Fall back to the full mask unless a non-empty sub-range can be computed.
    custom_mask = spec_info.custom_mask
    if spec_info.custom_mask is not None and spec_info.draft_token is not None:
        mask_start = _compute_mask_offset(start_seq_index, spec_info)
        if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
            mask_end = spec_info.custom_mask.shape[0]
        else:
            mask_end = _compute_mask_offset(end_seq_index, spec_info)
        if mask_end > mask_start:
            custom_mask = spec_info.custom_mask[mask_start:mask_end]

    positions = _slice_or_none(spec_info.positions, token_range)
    retrive_index = _slice_or_none(spec_info.retrive_index, seq_range)
    retrive_next_token = _slice_or_none(spec_info.retrive_next_token, seq_range)
    retrive_next_sibling = _slice_or_none(spec_info.retrive_next_sibling, seq_range)
    retrive_cum_len = _slice_or_none(spec_info.retrive_cum_len, seq_range)
    seq_lens_cpu = _slice_or_none(spec_info.seq_lens_cpu, seq_range)
    seq_lens_sum = None if seq_lens_cpu is None else seq_lens_cpu.sum()

    return replace(
        spec_info,
        custom_mask=custom_mask,
        draft_token=draft_token,
        positions=positions,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        retrive_cum_len=retrive_cum_len,
        seq_lens_cpu=seq_lens_cpu,
        seq_lens_sum=seq_lens_sum,
    )
def compute_split_token_index(
    split_seq_index: int,
    forward_mode: "ForwardMode",
    extend_seq_lens: Optional[Sequence[int]],
    token_num_per_seq: Optional[int],
) -> int:
    """Translate a sequence split index into the matching token split index.

    Args:
        split_seq_index: Index of the first sequence of the second sub-batch.
        forward_mode: Mode of the batch being split.
        extend_seq_lens: Per-sequence extend lengths; required for EXTEND.
        token_num_per_seq: Tokens per sequence; required for decode and
            target-verify.

    Returns:
        The index of the first token that belongs to the second sub-batch.

    Raises:
        NotImplementedError: If the forward mode is none of the handled ones.
    """
    if forward_mode == ForwardMode.EXTEND:
        assert extend_seq_lens is not None
        if _is_two_chunk_split_enabled(extend_seq_lens):
            # Two-chunk mode cuts in the middle of the token stream.
            return sum(extend_seq_lens) // 2
        tokens_before_split = extend_seq_lens[:split_seq_index]
        return sum(tokens_before_split)

    if forward_mode.is_target_verify() or forward_mode.is_decode():
        assert token_num_per_seq is not None
        return split_seq_index * token_num_per_seq

    if forward_mode.is_idle():
        assert split_seq_index == 0
        return 0

    raise NotImplementedError
def compute_split_indices_for_cuda_graph_replay(
    forward_mode: ForwardMode,
    cuda_graph_num_tokens: int,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
    """Compute the sequence and token split indices used when replaying a CUDA graph.

    Args:
        forward_mode: Mode of the batch being replayed. IDLE is treated as
            DECODE for the purpose of splitting.
        cuda_graph_num_tokens: Number of tokens of the captured graph.
        spec_info: Speculative-decoding info, forwarded to
            `get_token_num_per_seq`.

    Returns:
        A `(tbo_split_seq_index, tbo_split_token_index)` tuple.
    """
    # An idle batch is split exactly like a decode batch.
    forward_mode_for_tbo_split = (
        forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
    )
    # The per-sequence token count must come from the same (remapped) mode that
    # is used for splitting. Deriving it from the raw IDLE mode yields 0, which
    # made `compute_split_seq_index` divide by zero on its decode path.
    token_num_per_seq = get_token_num_per_seq(
        forward_mode=forward_mode_for_tbo_split, spec_info=spec_info
    )

    tbo_split_seq_index = compute_split_seq_index(
        forward_mode=forward_mode_for_tbo_split,
        num_tokens=cuda_graph_num_tokens,
        extend_lens=None,
        token_num_per_seq=token_num_per_seq,
    )
    tbo_split_token_index = compute_split_token_index(
        split_seq_index=tbo_split_seq_index,
        forward_mode=forward_mode_for_tbo_split,
        extend_seq_lens=None,
        token_num_per_seq=token_num_per_seq,
    )
    return tbo_split_seq_index, tbo_split_token_index
# -------------------------------- Preparation ---------------------------------------
class TboCudaGraphRunnerPlugin:
    """Keeps the TBO-specific state needed while capturing/replaying CUDA graphs.

    The plugin owns a two-element int32 tensor holding the non-padded token
    count of each TBO child; it is updated in place on capture and on replay.
    """

    def __init__(self):
        self._tbo_children_num_token_non_padded = torch.zeros(2, dtype=torch.int32)

    def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
        """Prepare `batch` for TBO while one batch size is being captured.

        Does nothing unless `enable_two_batch_overlap` is set in
        `global_server_args_dict`. Otherwise sets `batch.tbo_split_seq_index`,
        refreshes the per-child token counts and builds the TBO children via
        `TboForwardBatchPreparer.prepare_raw`.
        """
        if not global_server_args_dict["enable_two_batch_overlap"]:
            return

        tokens_per_seq = get_token_num_per_seq(
            forward_mode=batch.forward_mode, spec_info=batch.spec_info
        )
        batch.tbo_split_seq_index = compute_split_seq_index(
            forward_mode=batch.forward_mode,
            num_tokens=num_tokens,
            extend_lens=None,
            token_num_per_seq=tokens_per_seq,
        )
        # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
        assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"

        children_counts = self._tbo_children_num_token_non_padded
        # Write into the existing tensor so its storage stays the same object.
        children_counts[...] = (
            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch)
        )
        TboForwardBatchPreparer.prepare_raw(
            batch, tbo_children_num_token_non_padded=children_counts
        )

    def replay_prepare(
        self,
        forward_mode: ForwardMode,
        bs: int,
        num_token_non_padded: int,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Refresh the per-child non-padded token counts before a graph replay.

        Args:
            forward_mode: Mode of the batch about to be replayed.
            bs: Batch size of the replayed graph.
            num_token_non_padded: Number of real (non-padded) tokens.
            spec_info: Speculative-decoding info, if any.
        """
        tokens_per_seq = get_token_num_per_seq(
            forward_mode=forward_mode, spec_info=spec_info
        )
        _, split_token_index = compute_split_indices_for_cuda_graph_replay(
            forward_mode=forward_mode,
            cuda_graph_num_tokens=bs * tokens_per_seq,
            spec_info=spec_info,
        )

        new_counts = (
            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(
                tbo_split_token_index=split_token_index,
                num_token_non_padded=num_token_non_padded,
            )
        )
        self._tbo_children_num_token_non_padded[...] = new_counts
class TboDPAttentionPreparer:
    """Decides, across data-parallel ranks, whether TBO can run for a step.

    `prepare_all_gather` computes this rank's contribution and
    `compute_output` combines the gathered contributions of all ranks.
    """

    def prepare_all_gather(
        self,
        local_batch: ScheduleBatch,
        deepep_mode: DeepEPMode,
        enable_deepep_moe: bool,
        enable_two_batch_overlap: bool,
    ):
        """Compute this rank's TBO capability and forward mode.

        Stores `enable_two_batch_overlap` and `local_tbo_split_seq_index` on
        the instance for later use by `compute_output`.

        Returns:
            `(local_can_run_tbo, local_forward_mode)` where the second item is
            the integer value of the local forward mode (IDLE when there is no
            local batch).
        """
        self.enable_two_batch_overlap = enable_two_batch_overlap

        if local_batch is None:
            # A rank without work never prevents the other ranks from running TBO.
            self.local_tbo_split_seq_index = 0
            local_can_run_tbo = True
        else:
            forward_mode = local_batch.forward_mode
            token_num_per_seq = get_token_num_per_seq(
                forward_mode=forward_mode, spec_info=local_batch.spec_info
            )

            if forward_mode.is_target_verify() or forward_mode.is_decode():
                num_tokens = local_batch.batch_size() * token_num_per_seq
            else:
                num_tokens = local_batch.extend_num_tokens

            self.local_tbo_split_seq_index = compute_split_seq_index(
                forward_mode=forward_mode,
                num_tokens=num_tokens,
                extend_lens=local_batch.extend_lens,
                token_num_per_seq=token_num_per_seq,
            )
            resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)

            if self.local_tbo_split_seq_index is None:
                local_can_run_tbo = False
            else:
                # A non-verify extend batch under low-latency DeepEP rules TBO out.
                blocked = (
                    forward_mode.is_extend()
                    and not forward_mode.is_target_verify()
                    and enable_deepep_moe
                    and resolved_deepep_mode == DeepEPMode.LOW_LATENCY
                )
                local_can_run_tbo = not blocked

        local_forward_mode = self._compute_local_forward_mode(local_batch)
        return local_can_run_tbo, local_forward_mode

    def compute_output(self, partial_global_info):
        """Combine the gathered per-rank info into the final TBO decision.

        Args:
            partial_global_info: Indexable as `[:, 0, 0]` (per-rank "can run
                TBO" flags) and `[:, 0, 1]` (per-rank forward mode values).

        Returns:
            `(tbo_split_seq_index, global_forward_mode)`; both are None when
            TBO cannot run.
        """
        every_rank_can_run = min(partial_global_info[:, 0, 0].tolist())
        forward_modes = partial_global_info[:, 0, 1].tolist()
        global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(
            forward_modes
        )

        if (
            self.enable_two_batch_overlap
            and every_rank_can_run
            and forward_mode_agree
        ):
            return self.local_tbo_split_seq_index, global_forward_mode
        return None, None

    @staticmethod
    def _compute_local_forward_mode(local_batch):
        """Return the integer value of the local forward mode (IDLE if no batch)."""
        mode = ForwardMode.IDLE if local_batch is None else local_batch.forward_mode
        return mode.value

    @staticmethod
    def _compute_global_forward_mode(forward_modes):
        """Derive the common forward mode of all non-idle ranks.

        Returns:
            `(mode, agree)`: `(ForwardMode.IDLE, False)` when every rank is
            idle, `(common_mode, True)` when all non-idle ranks share one mode,
            and `(None, False)` otherwise.
        """
        non_idle_modes = [
            mode for mode in forward_modes if mode != ForwardMode.IDLE.value
        ]

        if not non_idle_modes:
            return ForwardMode.IDLE, False

        if TboDPAttentionPreparer._is_all_same(non_idle_modes):
            return ForwardMode(non_idle_modes[0]), True
        return None, False

    @staticmethod
    def _is_all_same(x):
        """Return True when every element of `x` equals the first one (or `x` is empty)."""
        for value in x:
            if not (value == x[0]):
                return False
        return True
class TboForwardBatchPreparer:
    """Builds the two child `ForwardBatch` objects used by two-batch overlap.

    All methods are classmethods; the class itself keeps no state.
    """

    @classmethod
    def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
        """Populate `batch.tbo_children` when TBO applies to `batch`.

        Does nothing when `batch.tbo_split_seq_index` is None or when
        `is_draft_worker` is true.
        """
        if batch.tbo_split_seq_index is None or is_draft_worker:
            return

        tbo_children_num_token_non_padded = (
            cls.compute_tbo_children_num_token_non_padded(batch)
        )
        cls.prepare_raw(
            batch, tbo_children_num_token_non_padded=tbo_children_num_token_non_padded
        )

    @classmethod
    def prepare_raw(
        cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor
    ):
        """Split `batch` into two children and store them in `batch.tbo_children`.

        Args:
            batch: Parent batch; `batch.tbo_split_seq_index` must already be
                set and `batch.attn_backend` must be a `TboAttnBackend`.
            tbo_children_num_token_non_padded: Two-element value unpacked into
                the `num_token_non_padded` of child A and child B.
        """
        # Imported inside the method rather than at module level (presumably to
        # avoid an import cycle -- TODO confirm).
        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

        tbo_split_token_index = cls._compute_split_token_index(batch)

        # Two-chunk mode: the sequence at the split index is shared by both children.
        is_enable_two_chunk = (
            batch.forward_mode == ForwardMode.EXTEND
            and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
        )

        if _tbo_debug:
            logger.info(
                f"TboForwardBatchPreparer.prepare "
                f"is_enable_two_chunk={is_enable_two_chunk} "
                f"tbo_split_seq_index={batch.tbo_split_seq_index} "
                f"tbo_split_token_index={tbo_split_token_index} "
                f"extend_seq_lens={batch.extend_seq_lens_cpu} "
                f"bs={batch.batch_size} "
                f"forward_mode={batch.forward_mode}"
            )

        assert isinstance(batch.attn_backend, TboAttnBackend)
        attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children

        [out_num_token_non_padded_a, out_num_token_non_padded_b] = (
            tbo_children_num_token_non_padded
        )

        # Child A: tokens [0, split). In two-chunk mode it also includes the
        # sequence at the split index (hence the `+ 1`).
        child_a = cls.filter_batch(
            batch,
            start_token_index=0,
            end_token_index=tbo_split_token_index,
            start_seq_index=0,
            end_seq_index=(
                batch.tbo_split_seq_index + 1
                if is_enable_two_chunk
                else batch.tbo_split_seq_index
            ),
            output_attn_backend=attn_backend_child_a,
            out_num_token_non_padded=out_num_token_non_padded_a,
        )
        # Child B: tokens [split, end) and sequences from the split index onwards.
        child_b = cls.filter_batch(
            batch,
            start_token_index=tbo_split_token_index,
            end_token_index=batch.input_ids.shape[0],
            start_seq_index=batch.tbo_split_seq_index,
            end_seq_index=batch.batch_size,
            output_attn_backend=attn_backend_child_b,
            out_num_token_non_padded=out_num_token_non_padded_b,
        )

        if is_enable_two_chunk:
            # Fix up the length fields of the sequence shared by both children.
            cls.derive_fields_related_to_seq_len_for_two_chunk(
                batch,
                child_a=child_a,
                child_b=child_b,
                tbo_split_seq_index=batch.tbo_split_seq_index,
            )

        assert batch.tbo_children is None
        batch.tbo_children = [child_a, child_b]

    @classmethod
    def derive_fields_related_to_seq_len_for_two_chunk(
        cls,
        batch: ForwardBatch,
        *,
        child_a: ForwardBatch,
        child_b: ForwardBatch,
        tbo_split_seq_index: int,
    ):
        """Adjust length-related fields of both children for a two-chunk split.

        The sequence at `tbo_split_seq_index` is divided so that child A ends
        up with exactly `sum(extend_seq_lens_cpu) // 2` extend tokens; the
        remaining tokens of that sequence go to child B. Both children are
        modified in place.
        """
        extend_seq_lens_cpu = batch.extend_seq_lens_cpu
        overall_seq_lens_sum = sum(extend_seq_lens_cpu)
        half_seq_lens_sum = overall_seq_lens_sum // 2
        # Tokens of the shared sequence that stay in child A ...
        left_last_seq_token_num = half_seq_lens_sum - sum(
            extend_seq_lens_cpu[:tbo_split_seq_index]
        )
        # ... and the rest of that sequence, which goes to child B.
        right_first_seq_token_num = (
            extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
        )

        # making deepcopy to be extra safe
        child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
        child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
        child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
        child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
        # Re-derive `extend_seq_lens` and `extend_num_tokens` from the edited CPU lists.
        for child in [child_a, child_b]:
            _update_device_and_sum_field_from_cpu_field(
                batch=child,
                cpu_field="extend_seq_lens_cpu",
                device_field="extend_seq_lens",
                sum_field="extend_num_tokens",
            )

        assert (
            child_a.extend_num_tokens == half_seq_lens_sum
        ), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"

        # Child A's last sequence now ends after its own part of the shared sequence.
        child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
        child_a.seq_lens_cpu[-1] = (
            child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
        )
        _update_device_and_sum_field_from_cpu_field(
            batch=child_a,
            cpu_field="seq_lens_cpu",
            device_field="seq_lens",
            sum_field="seq_lens_sum",
        )

        # For child B the part handled by child A is added to the prefix length
        # of its first sequence.
        child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
        child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
        _update_device_and_sum_field_from_cpu_field(
            batch=child_b,
            cpu_field="extend_prefix_lens_cpu",
            device_field="extend_prefix_lens",
            sum_field=None,
        )
        # Only the second return value of `compute_position` is kept, as child
        # B's `extend_start_loc`.
        _, child_b.extend_start_loc = compute_position(
            global_server_args_dict["attention_backend"],
            child_b.extend_prefix_lens,
            child_b.extend_seq_lens,
            child_b.extend_num_tokens,
        )

    @classmethod
    def filter_batch(
        cls,
        batch: ForwardBatch,
        *,
        start_token_index: int,
        end_token_index: int,
        start_seq_index: int,
        end_seq_index: int,
        output_attn_backend: AttentionBackend,
        out_num_token_non_padded: torch.Tensor,
    ):
        """Create a child `ForwardBatch` covering one token/sequence range.

        Args:
            batch: Parent batch.
            start_token_index: First token of the child (inclusive).
            end_token_index: Token end of the child (exclusive).
            start_seq_index: First sequence of the child (inclusive).
            end_seq_index: Sequence end of the child (exclusive).
            output_attn_backend: Attention backend assigned to the child.
            out_num_token_non_padded: Value stored as the child's
                `num_token_non_padded`.

        Returns:
            A new `ForwardBatch` built from the sliced/copied fields.

        Raises:
            Exception: If the parent has a non-None dataclass field that this
                method does not explicitly handle.
        """
        assert (
            end_token_index >= start_token_index
        ), f"{end_token_index=}, {start_token_index=}, batch={batch}"
        num_tokens = batch.input_ids.shape[0]
        num_seqs = batch.batch_size

        output_dict = dict()

        # Fields with one entry per token: slice by the token range.
        for key in [
            "input_ids",
            "positions",
            "out_cache_loc",
        ]:
            old_value = getattr(batch, key)
            assert (
                old_value.shape[0] == num_tokens
            ), f"{key=} {old_value=} {num_tokens=} {batch=}"
            output_dict[key] = old_value[start_token_index:end_token_index]

        # Fields with one entry per sequence: slice by the sequence range.
        for key in [
            "req_pool_indices",
            "seq_lens",
            "seq_lens_cpu",
            "extend_seq_lens",
            "extend_prefix_lens",
            "extend_start_loc",
            "extend_prefix_lens_cpu",
            "extend_seq_lens_cpu",
            "extend_logprob_start_lens_cpu",
            "lora_ids",
        ]:
            old_value = getattr(batch, key)
            if old_value is None:
                # Missing fields are left out of `output_dict` entirely.
                continue
            elif batch.forward_mode.is_target_verify() and (
                key == "extend_seq_lens"
                or key == "extend_prefix_lens"
                or key == "extend_start_loc"
                or key == "extend_prefix_lens_cpu"
                or key == "extend_seq_lens_cpu"
                or key == "extend_logprob_start_lens_cpu"
            ):
                # In target-verify mode the extend-related fields are cleared
                # instead of being sliced.
                output_dict[key] = None
                continue
            assert (
                len(old_value) == num_seqs
            ), f"{key=} {old_value=} {num_seqs=} {batch=}"
            output_dict[key] = old_value[start_seq_index:end_seq_index]

        spec_info = getattr(batch, "spec_info")
        output_spec_info = split_spec_info(
            spec_info=spec_info,
            start_token_index=start_token_index,
            end_token_index=end_token_index,
            start_seq_index=start_seq_index,
            end_seq_index=end_seq_index,
        )
        output_dict["spec_info"] = output_spec_info

        # Fields copied to the child unchanged.
        for key in [
            "forward_mode",
            "is_extend_in_batch",
            "return_logprob",
            "req_to_token_pool",
            "token_to_kv_pool",
            "can_run_dp_cuda_graph",
            "global_forward_mode",
            "spec_algorithm",
            "capture_hidden_mode",
            "padded_static_len",
            "mrope_positions",  # only used by qwen2-vl, thus not care
            "split_index",  # for split prefill
            "orig_seq_lens",  # only used by qwen-1m, thus not care
        ]:
            output_dict[key] = getattr(batch, key)
        if not batch.forward_mode.is_target_verify():
            # Sanity check: the parent's `extend_num_tokens` must agree with the
            # value derived from its `input_ids` and forward mode.
            assert (
                _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
                == batch.extend_num_tokens
            ), f"{batch=}"
        extend_num_tokens = _compute_extend_num_tokens(
            output_dict["input_ids"], output_dict["forward_mode"]
        )

        # TODO improve, e.g. unify w/ `init_raw`
        if (
            global_server_args_dict["moe_dense_tp_size"] == 1
            and batch.gathered_buffer is not None
        ):
            # Allocate a fresh zero buffer with one row per child token and the
            # parent's column count, dtype and device.
            sum_len = end_token_index - start_token_index
            gathered_buffer = torch.zeros(
                (sum_len, batch.gathered_buffer.shape[1]),
                dtype=batch.gathered_buffer.dtype,
                device=batch.gathered_buffer.device,
            )
        else:
            gathered_buffer = None

        # Derived fields plus fields that are explicitly reset for the child.
        output_dict.update(
            dict(
                batch_size=end_seq_index - start_seq_index,
                seq_lens_sum=(
                    output_dict["seq_lens_cpu"].sum()
                    if "seq_lens_cpu" in output_dict
                    else None
                ),
                extend_num_tokens=extend_num_tokens,
                attn_backend=output_attn_backend,
                num_token_non_padded=out_num_token_non_padded,
                tbo_split_seq_index=None,
                tbo_parent_token_range=(start_token_index, end_token_index),
                tbo_children=None,
                global_num_tokens_gpu=None,
                global_num_tokens_cpu=None,
                dp_padding_mode=None,
                gathered_buffer=gathered_buffer,
                global_num_tokens_for_logprob_gpu=None,
                global_num_tokens_for_logprob_cpu=None,
                sampling_info=None,
                # For logits and logprobs post processing, thus we do not care
                temp_scaled_logprobs=False,
                temperature=None,
                top_p_normalized_logprobs=False,
                top_p=None,
                mm_inputs=None,
                top_logprobs_nums=None,
                token_ids_logprobs=None,
                next_token_logits_buffer=None,
            )
        )

        # Guard against silently dropping data: every parent field that holds a
        # value must have been handled above.
        errors = []
        for field in dataclasses.fields(ForwardBatch):
            if getattr(batch, field.name) is not None and field.name not in output_dict:
                errors.append(
                    f"Field {field.name} has value, but is not yet supported (value={getattr(batch, field.name)} batch={batch})"
                )
        if len(errors) > 0:
            raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors))

        return ForwardBatch(**output_dict)

    @classmethod
    def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
        """Return the non-padded token counts of both children of `batch`.

        Uses `len(batch.input_ids)` as the parent's non-padded token count.
        """
        return cls.compute_tbo_children_num_token_non_padded_raw(
            tbo_split_token_index=cls._compute_split_token_index(batch),
            num_token_non_padded=len(batch.input_ids),
        )

    @classmethod
    def compute_tbo_children_num_token_non_padded_raw(
        cls, tbo_split_token_index: int, num_token_non_padded: int
    ):
        """Distribute `num_token_non_padded` over the two children.

        Child A receives up to `tbo_split_token_index` tokens and child B the
        remainder (never negative).

        Returns:
            An int32 tensor `[value_a, value_b]` moved to
            `global_server_args_dict["device"]`.
        """
        # TODO we may make padding on both sub-batches to make it slightly more balanced
        value_a = min(tbo_split_token_index, num_token_non_padded)
        value_b = max(0, num_token_non_padded - tbo_split_token_index)
        return torch.tensor([value_a, value_b], dtype=torch.int32).to(
            device=global_server_args_dict["device"], non_blocking=True
        )

    @classmethod
    def _compute_split_token_index(cls, batch: ForwardBatch):
        """Return the token index at which `batch` is split, from its own fields."""
        token_num_per_seq = get_token_num_per_seq(
            forward_mode=batch.forward_mode, spec_info=batch.spec_info
        )
        return compute_split_token_index(
            split_seq_index=batch.tbo_split_seq_index,
            forward_mode=batch.forward_mode,
            extend_seq_lens=batch.extend_seq_lens_cpu,
            token_num_per_seq=token_num_per_seq,
        )
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
    """Return the number of extend tokens implied by `input_ids` and the mode.

    Returns:
        None for decode, idle and target-verify modes; `input_ids.shape[0]`
        for any other extend mode.

    Raises:
        NotImplementedError: If the mode matches none of the cases above.
    """
    # These checks come first, so a mode that satisfies one of them yields
    # None even if it also reports `is_extend()`.
    has_no_extend_tokens = (
        forward_mode.is_decode()
        or forward_mode.is_idle()
        or forward_mode.is_target_verify()
    )
    if has_no_extend_tokens:
        return None
    if forward_mode.is_extend():
        return input_ids.shape[0]
    raise NotImplementedError
# -------------------------------- Execution ---------------------------------------
def model_forward_maybe_tbo(
    layers,
    enable_tbo: bool,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    hidden_states: torch.Tensor,
    input_data_scatter_mode: ScatterMode,
    residual: Optional[torch.Tensor],
    zero_allocator: Optional[BumpAllocator] = None,
):
    """Run `layers` either with two-batch overlap or as a single batch.

    Args:
        layers: Layers to execute; the first one supplies the layer input
            scatter mode.
        enable_tbo: Whether to take the two-batch-overlap path.
        positions: Position tensor passed to the layers.
        forward_batch: Batch metadata passed to the layers.
        hidden_states: Input hidden states.
        input_data_scatter_mode: Scatter mode of the incoming tensors.
        residual: Optional residual tensor.
        zero_allocator: Optional allocator passed through to the layers.

    Returns:
        Whatever the selected execution path returns (a
        `(hidden_states, residual)` pair).
    """
    layer_inputs = {
        "positions": positions,
        "hidden_states": hidden_states,
        "forward_batch": forward_batch,
        "residual": residual,
        "zero_allocator": zero_allocator,
    }
    layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
    strategy = OperationsStrategy.init_new_tbo(
        layers, forward_batch.global_forward_mode
    )

    if not enable_tbo:
        return _model_forward_non_tbo(layer_inputs, strategy)

    return _model_forward_tbo(
        inputs=layer_inputs,
        operations_strategy=strategy,
        input_data_scatter_mode=input_data_scatter_mode,
        layer_input_scatter_mode=layer_input_scatter_mode,
    )
def _model_forward_tbo(
    inputs,
    operations_strategy: OperationsStrategy,
    input_data_scatter_mode: ScatterMode,
    layer_input_scatter_mode: ScatterMode,
):
    """Execute the layers on the two TBO children with overlapped operations.

    Args:
        inputs: Keyword arguments for `_model_forward_tbo_split_inputs`
            (everything except the two scatter modes).
        operations_strategy: Supplies the operations, the DeepGEMM SM count
            and the stage delta between the two children.
        input_data_scatter_mode: Scatter mode of the incoming tensors.
        layer_input_scatter_mode: Scatter mode expected by the first layer.

    Returns:
        The merged `(hidden_states, residual)` pair of both children.
    """
    # Local import: this module has no `empty_context` in scope, so the HIP
    # branch used to raise NameError. `nullcontext` is the stdlib no-op
    # context manager.
    from contextlib import nullcontext

    inputs_arr = _model_forward_tbo_split_inputs(
        **inputs,
        input_data_scatter_mode=input_data_scatter_mode,
        layer_input_scatter_mode=layer_input_scatter_mode,
    )
    # Drop the reference to the unsplit inputs before running the layers.
    del inputs

    # On HIP no DeepGEMM SM-count configuration is applied.
    context = (
        nullcontext()
        if _is_hip
        else deep_gemm_wrapper.configure_deep_gemm_num_sms(
            operations_strategy.deep_gemm_num_sms
        )
    )

    with context:
        outputs_arr = execute_overlapped_operations(
            inputs_arr=inputs_arr,
            operations_arr=[operations_strategy.operations] * 2,
            delta_stages=[0, operations_strategy.tbo_delta_stages],
        )

    return _model_forward_tbo_merge_outputs(*outputs_arr)
def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):
    """Execute the strategy's operations on the whole batch without splitting.

    Returns:
        The `(hidden_states, residual)` entries of the operations' output.
    """
    result = execute_operations(inputs, operations_strategy.operations)
    hidden_states = result["hidden_states"]
    residual = result["residual"]
    return hidden_states, residual
def _model_forward_tbo_split_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: Optional[BumpAllocator],
    input_data_scatter_mode: ScatterMode,
    layer_input_scatter_mode: ScatterMode,
) -> List[Dict]:
    """Split the layer inputs into one input dict per TBO child.

    The tensors are first converted from `input_data_scatter_mode` to
    `ScatterMode.TP_ATTN_FULL`, then sliced per child, and finally each child
    is converted to `layer_input_scatter_mode`.

    Returns:
        A list with one keyword-argument dict per child.
    """
    splitter_scatter_mode = ScatterMode.TP_ATTN_FULL
    context = CommunicateContext.init_new()

    def _convert(pair_hidden, pair_residual, batch, source_mode, target_mode):
        # Shared wrapper around the hidden-states/residual scatter-mode conversion.
        return CommunicateSummableTensorPairFn.execute(
            hidden_states_input_mode=source_mode,
            residual_input_mode=source_mode,
            output_mode=target_mode,
            hidden_states=pair_hidden,
            residual=pair_residual,
            forward_batch=batch,
            context=context,
        )

    hidden_states, residual = _convert(
        hidden_states,
        residual,
        forward_batch,
        input_data_scatter_mode,
        splitter_scatter_mode,
    )

    per_child_inputs = _model_forward_tbo_split_inputs_raw(
        hidden_states=hidden_states,
        residual=residual,
        positions=positions,
        forward_batch=forward_batch,
        zero_allocator=zero_allocator,
    )

    def _to_layer_input_mode(hidden_states, residual, forward_batch, **passthrough):
        # Convert one child's tensors and keep all other entries untouched.
        hidden_states, residual = _convert(
            hidden_states,
            residual,
            forward_batch,
            splitter_scatter_mode,
            layer_input_scatter_mode,
        )
        return dict(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            **passthrough,
        )

    return [_to_layer_input_mode(**child) for child in per_child_inputs]
def _model_forward_tbo_split_inputs_raw(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: Optional[BumpAllocator],
) -> List[Dict]:
    """Slice the inputs for every entry of `forward_batch.tbo_children`.

    Returns:
        One dict per child, produced by `_model_forward_filter_inputs`;
        `zero_allocator` is added to each dict only when it is not None.
    """
    allocator_kwargs = (
        {} if zero_allocator is None else dict(zero_allocator=zero_allocator)
    )

    split_inputs = []
    for subbatch_index, child_batch in enumerate(forward_batch.tbo_children):
        child_inputs = _model_forward_filter_inputs(
            hidden_states=hidden_states,
            residual=residual,
            positions=positions,
            output_forward_batch=child_batch,
            tbo_subbatch_index=subbatch_index,
        )
        split_inputs.append(dict(**child_inputs, **allocator_kwargs))
    return split_inputs
def _model_forward_filter_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    output_forward_batch: ForwardBatch,
    tbo_subbatch_index: int,
) -> Dict:
    """Cut the parent tensors down to the token range of one TBO child.

    The range is taken from `output_forward_batch.tbo_parent_token_range`.

    Returns:
        A dict with the sliced `hidden_states`, `residual` (None stays None)
        and `positions`, plus `forward_batch` and `tbo_subbatch_index`.
    """
    token_slice = slice(*output_forward_batch.tbo_parent_token_range)

    if residual is None:
        child_residual = None
    else:
        child_residual = residual[token_slice]

    return {
        "hidden_states": hidden_states[token_slice],
        "residual": child_residual,
        "positions": positions[token_slice],
        "forward_batch": output_forward_batch,
        "tbo_subbatch_index": tbo_subbatch_index,
    }
def _model_forward_tbo_merge_outputs(output_a, output_b):
    """Concatenate the outputs of both TBO children along dimension 0.

    Returns:
        `(hidden_states, residual)`; an entry is None when it is None in both
        children. An entry that is None in only one child trips an assertion.
    """

    def _merge(name):
        part_a, part_b = output_a[name], output_b[name]
        assert (part_a is None) == (part_b is None)
        return None if part_a is None else torch.concat([part_a, part_b], dim=0)

    hidden_states = _merge("hidden_states")
    residual = _merge("residual")
    return hidden_states, residual
# -------------------------------- Utilities and wrappers ---------------------------------------
class MaybeTboDeepEPDispatcher:
    """Routes dispatcher calls to one `DeepEPDispatcher` per TBO sub-batch.

    Two inner dispatchers are created when `enable_two_batch_overlap` is set
    in `global_server_args_dict`, otherwise one. Every public method forwards
    its keyword arguments to the inner dispatcher selected by the optional
    `tbo_subbatch_index` keyword (index 0 when it is missing, None or 0).
    """

    def __init__(self, **kwargs):
        if global_server_args_dict["enable_two_batch_overlap"]:
            inner_count = 2
        else:
            inner_count = 1
        self._inners = [DeepEPDispatcher(**kwargs) for _ in range(inner_count)]

    def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
        """Call method `name` of the selected inner dispatcher with `kwargs`."""
        inner = self._inners[tbo_subbatch_index or 0]
        method = getattr(inner, name)
        return method(**kwargs)

    def dispatch(self, **kwargs) -> DispatchOutput:
        """Forward to the inner dispatcher's `dispatch`."""
        return self._execute("dispatch", **kwargs)

    def dispatch_a(self, **kwargs):
        """Forward to the inner dispatcher's `dispatch_a`."""
        return self._execute("dispatch_a", **kwargs)

    def dispatch_b(self, **kwargs):
        """Forward to the inner dispatcher's `dispatch_b`."""
        return self._execute("dispatch_b", **kwargs)

    def combine(self, **kwargs) -> torch.Tensor:
        """Forward to the inner dispatcher's `combine`."""
        return self._execute("combine", **kwargs)

    def combine_a(self, **kwargs):
        """Forward to the inner dispatcher's `combine_a`."""
        return self._execute("combine_a", **kwargs)

    def combine_b(self, **kwargs):
        """Forward to the inner dispatcher's `combine_b`."""
        return self._execute("combine_b", **kwargs)