Minor refactor two-batch overlap (#6682)
This commit is contained in:
@@ -119,24 +119,15 @@ class TboAttnBackend(AttentionBackend):
|
||||
replay_seq_lens_sum: int = None,
|
||||
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
|
||||
):
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
|
||||
if fn_name == "init_forward_metadata_capture_cuda_graph":
|
||||
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
|
||||
num_tokens = bs
|
||||
|
||||
forward_mode_for_tbo_split = (
|
||||
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
|
||||
)
|
||||
tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
|
||||
forward_mode=forward_mode_for_tbo_split,
|
||||
num_tokens=num_tokens,
|
||||
extend_lens=None,
|
||||
)
|
||||
tbo_split_token_index = two_batch_overlap.compute_split_token_index(
|
||||
split_seq_index=tbo_split_seq_index,
|
||||
forward_mode=forward_mode_for_tbo_split,
|
||||
extend_seq_lens=None,
|
||||
tbo_split_seq_index, tbo_split_token_index = (
|
||||
two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
|
||||
forward_mode=forward_mode,
|
||||
cuda_graph_num_tokens=num_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
num_tokens_child_left = tbo_split_token_index
|
||||
|
||||
Reference in New Issue
Block a user