471 lines
16 KiB
Python
471 lines
16 KiB
Python
import dataclasses
|
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
|
|
|
|
import torch
|
|
|
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
|
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
|
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
|
from sglang.srt.operations_strategy import OperationsStrategy
|
|
from sglang.srt.utils import BumpAllocator, DeepEPMode
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
|
|
|
|
|
# -------------------------------- Compute Basic Info ---------------------------------------
|
|
|
|
|
|
# TODO: consider disabling TBO (two-batch overlap) when the batch size is too small, because it will slow things down
|
|
def compute_split_seq_index(
    forward_mode: "ForwardMode",
    num_tokens: int,
    extend_lens: Optional[Sequence[int]],
) -> Optional[int]:
    """Return the sequence index at which a batch is cut into two sub-batches.

    The result depends on the forward mode:

    * extend: the index picked by ``_split_array_by_half_sum(extend_lens)``;
      ``extend_lens`` is asserted to be non-``None``.
    * decode: ``num_tokens // 2``.
    * idle: ``0`` (``num_tokens`` is asserted to be ``0``).

    Raises:
        NotImplementedError: if the forward mode is none of the above.
    """
    if forward_mode.is_extend():
        assert extend_lens is not None
        return _split_array_by_half_sum(extend_lens)

    if forward_mode.is_decode():
        return num_tokens // 2

    if forward_mode.is_idle():
        assert num_tokens == 0
        return 0

    raise NotImplementedError
|
|
|
|
|
|
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
|
overall_sum = sum(arr)
|
|
left_sum = 0
|
|
min_diff = float("inf")
|
|
best_index = 0
|
|
|
|
for i in range(1, len(arr)):
|
|
left_sum += arr[i - 1]
|
|
right_sum = overall_sum - left_sum
|
|
diff = abs(left_sum - right_sum)
|
|
if diff <= min_diff:
|
|
min_diff = diff
|
|
best_index = i
|
|
else:
|
|
break
|
|
|
|
return best_index
|
|
|
|
|
|
def compute_split_token_index(
    split_seq_index: int,
    forward_mode: "ForwardMode",
    extend_seq_lens: Optional[Sequence[int]],
) -> int:
    """Convert a sequence-level split index into a token-level split index.

    * extend: the sum of ``extend_seq_lens[:split_seq_index]``;
      ``extend_seq_lens`` is asserted to be non-``None``.
    * decode: ``split_seq_index`` itself.
    * idle: ``0`` (``split_seq_index`` is asserted to be ``0``).

    Raises:
        NotImplementedError: if the forward mode is none of the above.
    """
    if forward_mode.is_extend():
        assert extend_seq_lens is not None
        lens_before_split = extend_seq_lens[:split_seq_index]
        return sum(lens_before_split)

    if forward_mode.is_decode():
        return split_seq_index

    if forward_mode.is_idle():
        assert split_seq_index == 0
        return 0

    raise NotImplementedError
|
|
|
|
|
|
# -------------------------------- Preparation ---------------------------------------
|
|
|
|
|
|
class TboCudaGraphRunnerUtils:
    """Two-batch-overlap (TBO) helpers that operate on a ``CudaGraphRunner``."""

    @staticmethod
    def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int):
        """Return the TBO split sequence index for a capture of ``num_tokens``.

        Returns ``None`` when ``enable_two_batch_overlap`` is off in the
        runner's server args. Otherwise the index is computed from the
        runner's ``capture_forward_mode`` with ``extend_lens=None``, and is
        asserted to be non-``None``.
        """
        if not that.model_runner.server_args.enable_two_batch_overlap:
            return None

        tbo_split_seq_index = compute_split_seq_index(
            forward_mode=that.capture_forward_mode,
            num_tokens=num_tokens,
            extend_lens=None,
        )
        # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
        assert (
            tbo_split_seq_index is not None
        ), f"{that.capture_forward_mode=} {num_tokens=}"
        return tbo_split_seq_index
|
|
|
|
|
|
class TboDPAttentionPreparer:
    """Decides whether two-batch overlap (TBO) can run for a batch.

    ``prepare_all_gather`` produces this rank's pair of values;
    ``compute_output`` later combines the pairs collected in
    ``partial_global_info`` into the final decision.
    """

    def prepare_all_gather(
        self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
    ):
        """Compute the local ``(can_run_tbo, forward_mode_value)`` pair.

        Stores ``enable_two_batch_overlap`` and the local split sequence index
        on ``self`` for use by ``compute_output``. With no local batch the
        split index is ``0`` and TBO is reported as runnable.
        """
        self.enable_two_batch_overlap = enable_two_batch_overlap

        if local_batch is None:
            self.local_tbo_split_seq_index = 0
            local_can_run_tbo = True
        else:
            batch_mode = local_batch.forward_mode
            self.local_tbo_split_seq_index = compute_split_seq_index(
                forward_mode=batch_mode,
                num_tokens=local_batch.input_ids.shape[0],
                extend_lens=local_batch.extend_lens,
            )
            resolved_deepep_mode = deepep_mode.resolve(batch_mode)
            has_split_index = self.local_tbo_split_seq_index is not None
            # TBO is ruled out for extend batches that use DeepEP MoE in
            # low-latency mode.
            local_can_run_tbo = has_split_index and not (
                batch_mode.is_extend()
                and enable_deepep_moe
                and (resolved_deepep_mode == DeepEPMode.low_latency)
            )

        local_forward_mode = self._compute_local_forward_mode(local_batch)
        return local_can_run_tbo, local_forward_mode

    def compute_output(self, partial_global_info):
        """Combine the gathered info into ``(tbo_split_seq_index, global_forward_mode)``.

        ``partial_global_info`` is indexed as ``[:, 0, 0]`` (can-run flags) and
        ``[:, 0, 1]`` (forward-mode values) and must support ``.tolist()``
        (presumably a tensor with one row per rank -- confirm against caller).

        Both returned values are ``None`` unless TBO is enabled, the minimum of
        the can-run flags is truthy, and all forward modes agree.
        """
        can_run_flags = partial_global_info[:, 0, 0].tolist()
        forward_modes = partial_global_info[:, 0, 1].tolist()
        local_can_run_tbo_aggregated = min(can_run_flags)

        global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(
            forward_modes
        )

        if not (
            self.enable_two_batch_overlap
            and local_can_run_tbo_aggregated
            and forward_mode_agree
        ):
            return None, None

        return self.local_tbo_split_seq_index, global_forward_mode

    @staticmethod
    def _compute_local_forward_mode(local_batch):
        """Return the ``.value`` of the batch's forward mode (IDLE if no batch)."""
        if local_batch is not None:
            mode = local_batch.forward_mode
        else:
            mode = ForwardMode.IDLE
        return mode.value

    @staticmethod
    def _compute_global_forward_mode(forward_modes):
        """Return ``(global_forward_mode, forward_mode_agree)``.

        IDLE values are treated as DECODE before comparison. The mode is
        ``None`` when the converted values are not all equal.
        """
        idle_value = ForwardMode.IDLE.value
        decode_value = ForwardMode.DECODE.value
        converted_forward_modes = [
            decode_value if mode_value == idle_value else mode_value
            for mode_value in forward_modes
        ]

        forward_mode_agree = TboDPAttentionPreparer._is_all_same(
            converted_forward_modes
        )
        if forward_mode_agree:
            global_forward_mode = ForwardMode(converted_forward_modes[0])
        else:
            global_forward_mode = None
        return global_forward_mode, forward_mode_agree

    @staticmethod
    def _is_all_same(x):
        """Return True if every element of ``x`` equals ``x[0]`` (True if empty)."""
        for item in x:
            if not (item == x[0]):
                return False
        return True
|
|
|
|
|
|
class TboForwardBatchPreparer:
    """Builds the two child ``ForwardBatch`` objects used by two-batch overlap."""

    @classmethod
    def prepare(cls, batch: ForwardBatch):
        """Split ``batch`` in two and store the halves in ``batch.tbo_children``.

        Does nothing when ``batch.tbo_split_seq_index`` is ``None``. Otherwise
        the batch's attention backend must be a ``TboAttnBackend`` whose two
        ``children`` become the attention backends of the two halves, and
        ``batch.tbo_children`` must still be ``None``.
        """
        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

        split_seq_index = batch.tbo_split_seq_index
        if split_seq_index is None:
            return

        split_token_index = compute_split_token_index(
            split_seq_index=split_seq_index,
            forward_mode=batch.forward_mode,
            extend_seq_lens=batch.extend_seq_lens_cpu,
        )

        assert isinstance(batch.attn_backend, TboAttnBackend)
        backend_a, backend_b = batch.attn_backend.children

        # One row per child:
        # (token start, token end, seq start, seq end, attention backend)
        child_specs = (
            (0, split_token_index, 0, split_seq_index, backend_a),
            (
                split_token_index,
                batch.input_ids.shape[0],
                split_seq_index,
                batch.batch_size,
                backend_b,
            ),
        )
        children = [
            cls.filter_batch(
                batch,
                start_token_index=token_start,
                end_token_index=token_end,
                start_seq_index=seq_start,
                end_seq_index=seq_end,
                output_attn_backend=child_backend,
            )
            for token_start, token_end, seq_start, seq_end, child_backend in child_specs
        ]

        assert batch.tbo_children is None
        batch.tbo_children = children

    @classmethod
    def filter_batch(
        cls,
        batch: ForwardBatch,
        *,
        start_token_index: int,
        end_token_index: int,
        start_seq_index: int,
        end_seq_index: int,
        output_attn_backend: AttentionBackend,
    ):
        """Create a ``ForwardBatch`` restricted to one token/sequence range.

        Per-token fields are sliced by the token range, per-sequence fields by
        the sequence range, and a fixed set of fields is copied unchanged. The
        remaining fields are set explicitly below.

        Raises:
            Exception: if ``batch`` has a non-``None`` dataclass field that is
                not handled here.
        """
        from sglang.srt.managers.schedule_batch import global_server_args_dict

        num_tokens = batch.input_ids.shape[0]
        num_seqs = batch.batch_size

        token_range = slice(start_token_index, end_token_index)
        seq_range = slice(start_seq_index, end_seq_index)

        per_token_keys = (
            "input_ids",
            "positions",
            "out_cache_loc",
        )
        per_seq_keys = (
            "req_pool_indices",
            "seq_lens",
            "seq_lens_cpu",
            "extend_seq_lens",
            "extend_prefix_lens",
            "extend_start_loc",
            "extend_prefix_lens_cpu",
            "extend_seq_lens_cpu",
            "extend_logprob_start_lens_cpu",
            "lora_paths",
        )
        passthrough_keys = (
            "forward_mode",
            "return_logprob",
            "req_to_token_pool",
            "token_to_kv_pool",
            "can_run_dp_cuda_graph",
            "global_forward_mode",
            "spec_info",
            "spec_algorithm",
            "capture_hidden_mode",
            "padded_static_len",
            "mrope_positions",  # only used by qwen2-vl, thus not care
        )

        output_dict = {}

        # NOTE: `key`, `old_value`, `num_tokens`, `num_seqs` and `batch` appear
        # in the `{name=}` f-strings below, so their names are part of the
        # assertion messages.
        for key in per_token_keys:
            old_value = getattr(batch, key)
            assert (
                old_value.shape[0] == num_tokens
            ), f"{key=} {old_value=} {num_tokens=} {batch=}"
            output_dict[key] = old_value[token_range]

        for key in per_seq_keys:
            old_value = getattr(batch, key)
            if old_value is None:
                # Absent per-sequence fields are left out of the child batch.
                continue
            assert (
                len(old_value) == num_seqs
            ), f"{key=} {old_value=} {num_seqs=} {batch=}"
            output_dict[key] = old_value[seq_range]

        for key in passthrough_keys:
            output_dict[key] = getattr(batch, key)

        # Sanity check on the parent before deriving the child's value.
        assert (
            _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
            == batch.extend_num_tokens
        ), f"{batch=}"
        extend_num_tokens = _compute_extend_num_tokens(
            output_dict["input_ids"], output_dict["forward_mode"]
        )

        # TODO improve, e.g. unify w/ `init_raw`
        if global_server_args_dict["moe_dense_tp_size"] == 1:
            sum_len = end_token_index - start_token_index
            parent_buffer = batch.gathered_buffer
            gathered_buffer = torch.zeros(
                (sum_len, parent_buffer.shape[1]),
                dtype=parent_buffer.dtype,
                device=parent_buffer.device,
            )
        else:
            gathered_buffer = None

        if "seq_lens_cpu" in output_dict:
            seq_lens_sum = output_dict["seq_lens_cpu"].sum()
        else:
            seq_lens_sum = None

        output_dict.update(
            batch_size=end_seq_index - start_seq_index,
            seq_lens_sum=seq_lens_sum,
            extend_num_tokens=extend_num_tokens,
            attn_backend=output_attn_backend,
            tbo_split_seq_index=None,
            tbo_parent_token_range=(start_token_index, end_token_index),
            tbo_children=None,
            global_num_tokens_gpu=None,
            global_num_tokens_cpu=None,
            gathered_buffer=gathered_buffer,
            global_num_tokens_for_logprob_gpu=None,
            global_num_tokens_for_logprob_cpu=None,
            sampling_info=None,
            # For logits and logprobs post processing, thus we do not care
            temp_scaled_logprobs=False,
            temperature=None,
            top_p_normalized_logprobs=False,
            top_p=None,
            mm_inputs=None,
            num_token_non_padded=None,
        )

        # Refuse to silently drop any populated field that is not handled above.
        errors = [
            f"Field {field.name} has value, but is not yet supported (value={getattr(batch, field.name)} batch={batch})"
            for field in dataclasses.fields(ForwardBatch)
            if getattr(batch, field.name) is not None and field.name not in output_dict
        ]
        if errors:
            raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors))

        return ForwardBatch(**output_dict)
|
|
|
|
|
|
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
    """Return ``input_ids.shape[0]`` in extend mode, ``None`` in decode/idle mode.

    Raises:
        NotImplementedError: for any other forward mode.
    """
    if forward_mode.is_extend():
        return input_ids.shape[0]

    if not (forward_mode.is_decode() or forward_mode.is_idle()):
        raise NotImplementedError

    return None
|
|
|
|
|
|
# -------------------------------- Execution ---------------------------------------
|
|
|
|
|
|
def model_forward_maybe_tbo(
    layers,
    enable_tbo: bool,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    zero_allocator: BumpAllocator,
):
    """Run ``layers`` with or without two-batch overlap, as chosen by ``enable_tbo``.

    The operations strategy is built from ``layers`` and
    ``forward_batch.global_forward_mode`` in both cases. Returns whatever the
    selected forward helper returns (a ``(hidden_states, residual)`` pair).
    """
    inputs = {
        "positions": positions,
        "hidden_states": hidden_states,
        "forward_batch": forward_batch,
        "residual": residual,
        "zero_allocator": zero_allocator,
    }
    operations_strategy = OperationsStrategy.init_new_tbo(
        layers, forward_batch.global_forward_mode
    )

    run_forward = _model_forward_tbo if enable_tbo else _model_forward_non_tbo
    return run_forward(inputs, operations_strategy)
|
|
|
|
|
|
def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
    """Run the forward pass as two overlapped sub-batches and merge the outputs.

    ``inputs`` is split into one input dict per child batch, both are executed
    with the same operation list (the second one offset by
    ``operations_strategy.tbo_delta_stages`` stages), and the two outputs are
    merged into a single ``(hidden_states, residual)`` pair.
    """
    # The attn_tp_size!=1 case is not yet extracted to master
    assert get_attention_tp_size() == 1

    inputs_arr = _model_forward_tbo_split_inputs(**inputs)
    # Drop this frame's reference to the unsplit inputs.
    del inputs

    operations = operations_strategy.operations
    with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
        outputs_arr = execute_overlapped_operations(
            inputs_arr=inputs_arr,
            operations_arr=[operations, operations],
            delta_stages=[0, operations_strategy.tbo_delta_stages],
        )

    output_a, output_b = outputs_arr
    return _model_forward_tbo_merge_outputs(output_a, output_b)
|
|
|
|
|
|
def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):
    """Run the strategy's operations on ``inputs`` without two-batch overlap.

    Returns the ``hidden_states`` and ``residual`` entries of the result.
    """
    result = execute_operations(inputs, operations_strategy.operations)
    hidden_states = result["hidden_states"]
    residual = result["residual"]
    return hidden_states, residual
|
|
|
|
|
|
def _model_forward_tbo_split_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: BumpAllocator,
) -> List[Dict]:
    """Build one input dict per entry of ``forward_batch.tbo_children``.

    Each dict holds the filtered inputs for that child plus the shared
    ``zero_allocator``.
    """
    split_inputs = []
    for child_index, child_batch in enumerate(forward_batch.tbo_children):
        child_inputs = _model_forward_filter_inputs(
            hidden_states=hidden_states,
            residual=residual,
            positions=positions,
            output_forward_batch=child_batch,
            tbo_subbatch_index=child_index,
        )
        split_inputs.append(dict(**child_inputs, zero_allocator=zero_allocator))
    return split_inputs
|
|
|
|
|
|
def _model_forward_filter_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    output_forward_batch: ForwardBatch,
    tbo_subbatch_index: int,
) -> Dict:
    """Slice the parent tensors down to one child batch's token range.

    The range comes from ``output_forward_batch.tbo_parent_token_range``.
    ``residual`` stays ``None`` if it was ``None``.
    """
    token_slice = slice(*output_forward_batch.tbo_parent_token_range)

    if residual is None:
        child_residual = None
    else:
        child_residual = residual[token_slice]

    return {
        "hidden_states": hidden_states[token_slice],
        "residual": child_residual,
        "positions": positions[token_slice],
        "forward_batch": output_forward_batch,
        "tbo_subbatch_index": tbo_subbatch_index,
    }
|
|
|
|
|
|
def _model_forward_tbo_merge_outputs(output_a, output_b):
|
|
def _handle_key(name):
|
|
value_a = output_a[name]
|
|
value_b = output_b[name]
|
|
assert (value_a is None) == (value_b is None)
|
|
if value_a is None:
|
|
return None
|
|
return torch.concat([value_a, value_b], dim=0)
|
|
|
|
return _handle_key("hidden_states"), _handle_key("residual")
|
|
|
|
|
|
# -------------------------------- Utilities and wrappers ---------------------------------------
|
|
|
|
|
|
class MaybeTboDeepEPDispatcher:
    """Routes dispatcher calls to one ``DeepEPDispatcher`` per TBO sub-batch.

    Two inner dispatchers are created when ``enable_two_batch_overlap`` is set
    in the global server args, otherwise one. Each public method forwards to
    the same-named method on the inner dispatcher selected by the optional
    ``tbo_subbatch_index`` keyword (index 0 when missing or falsy).
    """

    def __init__(self, **kwargs):
        """Create the inner dispatchers, passing ``kwargs`` to each of them."""
        if global_server_args_dict["enable_two_batch_overlap"]:
            num_inner_dispatchers = 2
        else:
            num_inner_dispatchers = 1

        self._inners = []
        for _ in range(num_inner_dispatchers):
            self._inners.append(DeepEPDispatcher(**kwargs))

    def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
        """Call method ``name`` on the selected inner dispatcher with ``kwargs``."""
        inner_index = tbo_subbatch_index if tbo_subbatch_index else 0
        target = self._inners[inner_index]
        method = getattr(target, name)
        return method(**kwargs)

    def dispatch(self, **kwargs):
        """Forward to the inner dispatcher's ``dispatch``."""
        return self._execute("dispatch", **kwargs)

    def dispatch_a(self, **kwargs):
        """Forward to the inner dispatcher's ``dispatch_a``."""
        return self._execute("dispatch_a", **kwargs)

    def dispatch_b(self, **kwargs):
        """Forward to the inner dispatcher's ``dispatch_b``."""
        return self._execute("dispatch_b", **kwargs)

    def combine(self, **kwargs):
        """Forward to the inner dispatcher's ``combine``."""
        return self._execute("combine", **kwargs)

    def combine_a(self, **kwargs):
        """Forward to the inner dispatcher's ``combine_a``."""
        return self._execute("combine_a", **kwargs)

    def combine_b(self, **kwargs):
        """Forward to the inner dispatcher's ``combine_b``."""
        return self._execute("combine_b", **kwargs)
|