Support overlapping two batches (#4068)

This commit is contained in:
fzyzcjy
2025-05-25 08:39:07 +08:00
committed by GitHub
parent f456037396
commit 0d47788025
13 changed files with 1145 additions and 129 deletions

View File

@@ -0,0 +1,241 @@
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch
from sglang.srt import two_batch_overlap
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
class TboAttnBackend(AttentionBackend):
def __init__(self, primary: AttentionBackend, children: List[AttentionBackend]):
super().__init__()
self.primary = primary
self.children = children
@classmethod
def init_new(cls, creator: Callable[[], AttentionBackend]):
return cls(
primary=creator(),
children=[creator() for _ in range(2)],
)
def init_forward_metadata(self, forward_batch: "ForwardBatch"):
self.primary.init_forward_metadata(forward_batch=forward_batch)
if forward_batch.tbo_children is not None:
for child, forward_batch_child in zip(
self.children, forward_batch.tbo_children, strict=True
):
if forward_batch_child.batch_size > 0:
child.init_forward_metadata(forward_batch=forward_batch_child)
def init_cuda_graph_state(self, max_bs: int):
self.primary.init_cuda_graph_state(max_bs=max_bs)
for item in self.children:
# TODO for children, maybe can provide *smaller* max_bs to optimize
item.init_cuda_graph_state(max_bs=max_bs)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
self.primary.init_forward_metadata_capture_cuda_graph(
bs=bs,
num_tokens=num_tokens,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=encoder_lens,
forward_mode=forward_mode,
spec_info=spec_info,
)
self._init_forward_metadata_cuda_graph_children(
fn_name="init_forward_metadata_capture_cuda_graph",
bs=bs,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=encoder_lens,
forward_mode=forward_mode,
spec_info=spec_info,
capture_num_tokens=num_tokens,
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
self.primary.init_forward_metadata_replay_cuda_graph(
bs=bs,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
seq_lens_sum=seq_lens_sum,
encoder_lens=encoder_lens,
forward_mode=forward_mode,
spec_info=spec_info,
seq_lens_cpu=seq_lens_cpu,
)
self._init_forward_metadata_cuda_graph_children(
fn_name="init_forward_metadata_replay_cuda_graph",
bs=bs,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=encoder_lens,
forward_mode=forward_mode,
spec_info=spec_info,
replay_seq_lens_sum=seq_lens_sum,
replay_seq_lens_cpu=seq_lens_cpu,
)
def _init_forward_metadata_cuda_graph_children(
self,
fn_name: str,
# common args
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
# capture args
capture_num_tokens: int = None,
# replay args
replay_seq_lens_sum: int = None,
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
from sglang.srt.model_executor.forward_batch_info import ForwardMode
if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
num_tokens = bs
forward_mode_for_tbo_split = (
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
)
tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
forward_mode=forward_mode_for_tbo_split,
num_tokens=num_tokens,
extend_lens=None,
)
tbo_split_token_index = two_batch_overlap.compute_split_token_index(
split_seq_index=tbo_split_seq_index,
forward_mode=forward_mode_for_tbo_split,
extend_seq_lens=None,
)
num_tokens_child_left = tbo_split_token_index
num_tokens_child_right = num_tokens - tbo_split_token_index
bs_child_left = num_tokens_child_left
bs_child_right = num_tokens_child_right
assert (
num_tokens_child_left > 0 and num_tokens_child_right > 0
), f"{num_tokens_child_left=} {num_tokens_child_right=} {forward_mode=} {num_tokens=}"
common_pre_split_args = dict(
fn_name=fn_name,
bs=bs,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=encoder_lens,
forward_mode=forward_mode,
spec_info=spec_info,
capture_num_tokens=capture_num_tokens,
replay_seq_lens_sum=replay_seq_lens_sum,
replay_seq_lens_cpu=replay_seq_lens_cpu,
)
args_left = _init_forward_metadata_cuda_graph_split(
output_bs=bs_child_left,
seq_slice=slice(None, tbo_split_seq_index),
**common_pre_split_args,
)
args_right = _init_forward_metadata_cuda_graph_split(
output_bs=bs_child_right,
seq_slice=slice(tbo_split_seq_index, None),
**common_pre_split_args,
)
child_left, child_right = self.children
getattr(child_left, fn_name)(**args_left)
getattr(child_right, fn_name)(**args_right)
def get_cuda_graph_seq_len_fill_value(self):
ans = self.primary.get_cuda_graph_seq_len_fill_value()
for child in self.children:
assert ans == child.get_cuda_graph_seq_len_fill_value()
return ans
def forward_extend(self, *args, **kwargs):
return self.primary.forward_extend(*args, **kwargs)
def forward_decode(self, *args, **kwargs):
return self.primary.forward_decode(*args, **kwargs)
def _init_forward_metadata_cuda_graph_split(
fn_name: str,
seq_slice: slice,
output_bs: int,
# common args
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
# capture args
capture_num_tokens: int = None,
# replay args
replay_seq_lens_sum: int = None,
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
assert encoder_lens is None, "encoder_lens is not supported yet"
assert spec_info is None, "spec_info is not supported yet"
ans = dict(
bs=output_bs,
req_pool_indices=req_pool_indices[seq_slice],
seq_lens=seq_lens[seq_slice],
# directly forward
forward_mode=forward_mode,
# ignore
encoder_lens=None,
spec_info=None,
)
if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
ans.update(
dict(
num_tokens=output_bs,
)
)
elif fn_name == "init_forward_metadata_replay_cuda_graph":
output_seq_lens_cpu = replay_seq_lens_cpu[seq_slice]
ans.update(
dict(
seq_lens_sum=output_seq_lens_cpu.sum().item(),
seq_lens_cpu=output_seq_lens_cpu,
)
)
else:
raise NotImplementedError
return ans