feat: support compatibility between MTP and two-batch-overlap (#7225)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend):
|
||||
replay_seq_lens_sum: int = None,
|
||||
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
|
||||
):
|
||||
token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
|
||||
forward_mode=forward_mode, spec_info=spec_info
|
||||
)
|
||||
if fn_name == "init_forward_metadata_capture_cuda_graph":
|
||||
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
|
||||
num_tokens = bs
|
||||
assert (
|
||||
capture_num_tokens == bs * token_num_per_seq
|
||||
), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
|
||||
num_tokens = bs * token_num_per_seq
|
||||
|
||||
tbo_split_seq_index, tbo_split_token_index = (
|
||||
two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
|
||||
forward_mode=forward_mode,
|
||||
cuda_graph_num_tokens=num_tokens,
|
||||
spec_info=spec_info,
|
||||
)
|
||||
)
|
||||
|
||||
num_tokens_child_left = tbo_split_token_index
|
||||
num_tokens_child_right = num_tokens - tbo_split_token_index
|
||||
bs_child_left = num_tokens_child_left
|
||||
bs_child_right = num_tokens_child_right
|
||||
bs_child_left = tbo_split_seq_index
|
||||
bs_child_right = bs - bs_child_left
|
||||
|
||||
assert (
|
||||
num_tokens_child_left > 0 and num_tokens_child_right > 0
|
||||
@@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: "ForwardMode",
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[EagleVerifyInput],
|
||||
# capture args
|
||||
capture_num_tokens: int = None,
|
||||
# replay args
|
||||
replay_seq_lens_sum: int = None,
|
||||
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
|
||||
):
|
||||
token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
|
||||
forward_mode=forward_mode, spec_info=spec_info
|
||||
)
|
||||
assert encoder_lens is None, "encoder_lens is not supported yet"
|
||||
assert spec_info is None, "spec_info is not supported yet"
|
||||
if spec_info is not None:
|
||||
output_spec_info = two_batch_overlap.split_spec_info(
|
||||
spec_info=spec_info,
|
||||
start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
|
||||
end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
|
||||
start_token_index=(
|
||||
seq_slice.start * token_num_per_seq
|
||||
if seq_slice.start is not None
|
||||
else 0
|
||||
),
|
||||
end_token_index=(
|
||||
seq_slice.stop * token_num_per_seq
|
||||
if seq_slice.stop is not None
|
||||
else bs * token_num_per_seq
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
output_spec_info = None
|
||||
ans = dict(
|
||||
bs=output_bs,
|
||||
req_pool_indices=req_pool_indices[seq_slice],
|
||||
@@ -208,14 +234,16 @@ def _init_forward_metadata_cuda_graph_split(
|
||||
forward_mode=forward_mode,
|
||||
# ignore
|
||||
encoder_lens=None,
|
||||
spec_info=None,
|
||||
spec_info=output_spec_info,
|
||||
)
|
||||
|
||||
if fn_name == "init_forward_metadata_capture_cuda_graph":
|
||||
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
|
||||
assert (
|
||||
capture_num_tokens == bs * token_num_per_seq
|
||||
), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
|
||||
ans.update(
|
||||
dict(
|
||||
num_tokens=output_bs,
|
||||
num_tokens=output_bs * token_num_per_seq,
|
||||
)
|
||||
)
|
||||
elif fn_name == "init_forward_metadata_replay_cuda_graph":
|
||||
|
||||
Reference in New Issue
Block a user