feat: support compatibility between MTP and two-batch-overlap (#7225)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend):
|
||||
replay_seq_lens_sum: int = None,
|
||||
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
|
||||
):
|
||||
token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
|
||||
forward_mode=forward_mode, spec_info=spec_info
|
||||
)
|
||||
if fn_name == "init_forward_metadata_capture_cuda_graph":
|
||||
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
|
||||
num_tokens = bs
|
||||
assert (
|
||||
capture_num_tokens == bs * token_num_per_seq
|
||||
), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
|
||||
num_tokens = bs * token_num_per_seq
|
||||
|
||||
tbo_split_seq_index, tbo_split_token_index = (
|
||||
two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
|
||||
forward_mode=forward_mode,
|
||||
cuda_graph_num_tokens=num_tokens,
|
||||
spec_info=spec_info,
|
||||
)
|
||||
)
|
||||
|
||||
num_tokens_child_left = tbo_split_token_index
|
||||
num_tokens_child_right = num_tokens - tbo_split_token_index
|
||||
bs_child_left = num_tokens_child_left
|
||||
bs_child_right = num_tokens_child_right
|
||||
bs_child_left = tbo_split_seq_index
|
||||
bs_child_right = bs - bs_child_left
|
||||
|
||||
assert (
|
||||
num_tokens_child_left > 0 and num_tokens_child_right > 0
|
||||
@@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: "ForwardMode",
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[EagleVerifyInput],
|
||||
# capture args
|
||||
capture_num_tokens: int = None,
|
||||
# replay args
|
||||
replay_seq_lens_sum: int = None,
|
||||
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
|
||||
):
|
||||
token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
|
||||
forward_mode=forward_mode, spec_info=spec_info
|
||||
)
|
||||
assert encoder_lens is None, "encoder_lens is not supported yet"
|
||||
assert spec_info is None, "spec_info is not supported yet"
|
||||
if spec_info is not None:
|
||||
output_spec_info = two_batch_overlap.split_spec_info(
|
||||
spec_info=spec_info,
|
||||
start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
|
||||
end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
|
||||
start_token_index=(
|
||||
seq_slice.start * token_num_per_seq
|
||||
if seq_slice.start is not None
|
||||
else 0
|
||||
),
|
||||
end_token_index=(
|
||||
seq_slice.stop * token_num_per_seq
|
||||
if seq_slice.stop is not None
|
||||
else bs * token_num_per_seq
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
output_spec_info = None
|
||||
ans = dict(
|
||||
bs=output_bs,
|
||||
req_pool_indices=req_pool_indices[seq_slice],
|
||||
@@ -208,14 +234,16 @@ def _init_forward_metadata_cuda_graph_split(
|
||||
forward_mode=forward_mode,
|
||||
# ignore
|
||||
encoder_lens=None,
|
||||
spec_info=None,
|
||||
spec_info=output_spec_info,
|
||||
)
|
||||
|
||||
if fn_name == "init_forward_metadata_capture_cuda_graph":
|
||||
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
|
||||
assert (
|
||||
capture_num_tokens == bs * token_num_per_seq
|
||||
), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
|
||||
ans.update(
|
||||
dict(
|
||||
num_tokens=output_bs,
|
||||
num_tokens=output_bs * token_num_per_seq,
|
||||
)
|
||||
)
|
||||
elif fn_name == "init_forward_metadata_replay_cuda_graph":
|
||||
|
||||
@@ -679,6 +679,7 @@ class CudaGraphRunner:
|
||||
forward_mode=self.capture_forward_mode,
|
||||
bs=bs,
|
||||
num_token_non_padded=len(forward_batch.input_ids),
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
|
||||
forward_batch.spec_info.custom_mask = self.custom_mask
|
||||
|
||||
@@ -352,7 +352,9 @@ class ForwardBatch:
|
||||
|
||||
if ret.forward_mode.is_idle():
|
||||
ret.positions = torch.empty((0,), device=device)
|
||||
TboForwardBatchPreparer.prepare(ret)
|
||||
TboForwardBatchPreparer.prepare(
|
||||
ret, is_draft_worker=model_runner.is_draft_worker
|
||||
)
|
||||
return ret
|
||||
|
||||
# Override the positions with spec_info
|
||||
@@ -397,7 +399,9 @@ class ForwardBatch:
|
||||
if model_runner.server_args.lora_paths is not None:
|
||||
model_runner.lora_manager.prepare_lora_batch(ret)
|
||||
|
||||
TboForwardBatchPreparer.prepare(ret)
|
||||
TboForwardBatchPreparer.prepare(
|
||||
ret, is_draft_worker=model_runner.is_draft_worker
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@@ -1039,7 +1039,7 @@ class ModelRunner:
|
||||
|
||||
def init_attention_backend(self):
|
||||
"""Init attention kernel backend."""
|
||||
if self.server_args.enable_two_batch_overlap:
|
||||
if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
|
||||
self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
|
||||
else:
|
||||
self.attn_backend = self._get_attention_backend()
|
||||
|
||||
@@ -71,7 +71,9 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
|
||||
assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
return _compute_moe_deepseek_blog_prefill(layer)
|
||||
elif forward_mode == ForwardMode.DECODE:
|
||||
elif (
|
||||
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
|
||||
):
|
||||
return _compute_moe_deepseek_blog_decode(layer)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
||||
@@ -146,7 +148,9 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
|
||||
assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
return _compute_moe_qwen3_prefill(layer)
|
||||
elif forward_mode == ForwardMode.DECODE:
|
||||
elif (
|
||||
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
|
||||
):
|
||||
return _compute_moe_qwen3_decode(layer)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
from dataclasses import replace
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -16,6 +17,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||
from sglang.srt.operations_strategy import OperationsStrategy
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
|
||||
|
||||
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
||||
@@ -26,17 +28,34 @@ logger = logging.getLogger(__name__)
|
||||
# -------------------------------- Compute Basic Info ---------------------------------------
|
||||
|
||||
|
||||
def get_token_num_per_seq(
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
|
||||
):
|
||||
if forward_mode.is_target_verify():
|
||||
return spec_info.draft_token_num
|
||||
elif forward_mode.is_decode():
|
||||
return 1
|
||||
elif forward_mode.is_idle():
|
||||
return 0
|
||||
else:
|
||||
# For extend, we should not use `token_num_per_seq`.
|
||||
return None
|
||||
|
||||
|
||||
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
|
||||
def compute_split_seq_index(
|
||||
forward_mode: "ForwardMode",
|
||||
num_tokens: int,
|
||||
extend_lens: Optional[Sequence[int]],
|
||||
token_num_per_seq: Optional[int],
|
||||
) -> Optional[int]:
|
||||
if forward_mode.is_extend():
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
assert extend_lens is not None
|
||||
return _split_array_by_half_sum(extend_lens)
|
||||
elif forward_mode.is_decode():
|
||||
return num_tokens // 2
|
||||
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
||||
assert token_num_per_seq is not None
|
||||
return (num_tokens // token_num_per_seq) // 2
|
||||
elif forward_mode.is_idle():
|
||||
assert num_tokens == 0
|
||||
return 0
|
||||
@@ -63,16 +82,103 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
||||
return best_index
|
||||
|
||||
|
||||
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
|
||||
if seq_index == 0:
|
||||
return 0
|
||||
|
||||
offset = 0
|
||||
max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])
|
||||
for i in range(max_seq_len):
|
||||
offset += (
|
||||
spec_info.seq_lens_cpu[i] + spec_info.draft_token_num
|
||||
) * spec_info.draft_token_num
|
||||
return offset
|
||||
|
||||
|
||||
def split_spec_info(
|
||||
spec_info: Optional[EagleVerifyInput],
|
||||
start_seq_index: int,
|
||||
end_seq_index: int,
|
||||
start_token_index: int,
|
||||
end_token_index: int,
|
||||
):
|
||||
if spec_info is None:
|
||||
return None
|
||||
if spec_info.draft_token is not None:
|
||||
draft_token = spec_info.draft_token[start_token_index:end_token_index]
|
||||
else:
|
||||
draft_token = None
|
||||
if spec_info.custom_mask is not None and spec_info.draft_token is not None:
|
||||
custom_mask_start = _compute_mask_offset(start_seq_index, spec_info)
|
||||
if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
|
||||
custom_mask_end = spec_info.custom_mask.shape[0]
|
||||
else:
|
||||
custom_mask_end = _compute_mask_offset(end_seq_index, spec_info)
|
||||
|
||||
if custom_mask_end > custom_mask_start:
|
||||
custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end]
|
||||
else:
|
||||
custom_mask = spec_info.custom_mask
|
||||
else:
|
||||
custom_mask = spec_info.custom_mask
|
||||
if spec_info.positions is not None:
|
||||
positions = spec_info.positions[start_token_index:end_token_index]
|
||||
else:
|
||||
positions = None
|
||||
if spec_info.retrive_index is not None:
|
||||
retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index]
|
||||
else:
|
||||
retrive_index = None
|
||||
if spec_info.retrive_next_token is not None:
|
||||
retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index]
|
||||
else:
|
||||
retrive_next_token = None
|
||||
if spec_info.retrive_next_sibling is not None:
|
||||
retrive_next_sibling = spec_info.retrive_next_sibling[
|
||||
start_seq_index:end_seq_index
|
||||
]
|
||||
else:
|
||||
retrive_next_sibling = None
|
||||
if spec_info.retrive_cum_len is not None:
|
||||
retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index]
|
||||
else:
|
||||
retrive_cum_len = None
|
||||
|
||||
if spec_info.seq_lens_cpu is not None:
|
||||
seq_lens_cpu = spec_info.seq_lens_cpu[start_seq_index:end_seq_index]
|
||||
else:
|
||||
seq_lens_cpu = None
|
||||
if seq_lens_cpu is not None:
|
||||
seq_lens_sum = seq_lens_cpu.sum()
|
||||
else:
|
||||
seq_lens_sum = None
|
||||
output_spec_info = replace(
|
||||
spec_info,
|
||||
custom_mask=custom_mask,
|
||||
draft_token=draft_token,
|
||||
positions=positions,
|
||||
retrive_index=retrive_index,
|
||||
retrive_next_token=retrive_next_token,
|
||||
retrive_next_sibling=retrive_next_sibling,
|
||||
retrive_cum_len=retrive_cum_len,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
seq_lens_sum=seq_lens_sum,
|
||||
)
|
||||
return output_spec_info
|
||||
|
||||
|
||||
def compute_split_token_index(
|
||||
split_seq_index: int,
|
||||
forward_mode: "ForwardMode",
|
||||
extend_seq_lens: Optional[Sequence[int]],
|
||||
token_num_per_seq: Optional[int],
|
||||
) -> int:
|
||||
if forward_mode.is_extend():
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
assert extend_seq_lens is not None
|
||||
return sum(extend_seq_lens[:split_seq_index])
|
||||
elif forward_mode.is_decode():
|
||||
return split_seq_index
|
||||
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
||||
assert token_num_per_seq is not None
|
||||
return split_seq_index * token_num_per_seq
|
||||
elif forward_mode.is_idle():
|
||||
assert split_seq_index == 0
|
||||
return 0
|
||||
@@ -83,19 +189,25 @@ def compute_split_token_index(
|
||||
def compute_split_indices_for_cuda_graph_replay(
|
||||
forward_mode: ForwardMode,
|
||||
cuda_graph_num_tokens: int,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
forward_mode_for_tbo_split = (
|
||||
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
|
||||
)
|
||||
token_num_per_seq = get_token_num_per_seq(
|
||||
forward_mode=forward_mode, spec_info=spec_info
|
||||
)
|
||||
tbo_split_seq_index = compute_split_seq_index(
|
||||
forward_mode=forward_mode_for_tbo_split,
|
||||
num_tokens=cuda_graph_num_tokens,
|
||||
extend_lens=None,
|
||||
token_num_per_seq=token_num_per_seq,
|
||||
)
|
||||
tbo_split_token_index = compute_split_token_index(
|
||||
split_seq_index=tbo_split_seq_index,
|
||||
forward_mode=forward_mode_for_tbo_split,
|
||||
extend_seq_lens=None,
|
||||
token_num_per_seq=token_num_per_seq,
|
||||
)
|
||||
return tbo_split_seq_index, tbo_split_token_index
|
||||
|
||||
@@ -110,11 +222,15 @@ class TboCudaGraphRunnerPlugin:
|
||||
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
|
||||
if not global_server_args_dict["enable_two_batch_overlap"]:
|
||||
return
|
||||
token_num_per_seq = get_token_num_per_seq(
|
||||
forward_mode=batch.forward_mode, spec_info=batch.spec_info
|
||||
)
|
||||
|
||||
batch.tbo_split_seq_index = compute_split_seq_index(
|
||||
forward_mode=batch.forward_mode,
|
||||
num_tokens=num_tokens,
|
||||
extend_lens=None,
|
||||
token_num_per_seq=token_num_per_seq,
|
||||
)
|
||||
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
|
||||
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
|
||||
@@ -129,13 +245,20 @@ class TboCudaGraphRunnerPlugin:
|
||||
)
|
||||
|
||||
def replay_prepare(
|
||||
self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
|
||||
self,
|
||||
forward_mode: ForwardMode,
|
||||
bs: int,
|
||||
num_token_non_padded: int,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
token_num_per_seq = get_token_num_per_seq(
|
||||
forward_mode=forward_mode, spec_info=spec_info
|
||||
)
|
||||
tbo_split_seq_index, tbo_split_token_index = (
|
||||
compute_split_indices_for_cuda_graph_replay(
|
||||
forward_mode=forward_mode,
|
||||
# TODO support bs!=num_tokens
|
||||
cuda_graph_num_tokens=bs,
|
||||
cuda_graph_num_tokens=bs * token_num_per_seq,
|
||||
spec_info=spec_info,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -154,14 +277,29 @@ class TboDPAttentionPreparer:
|
||||
self.enable_two_batch_overlap = enable_two_batch_overlap
|
||||
|
||||
if local_batch is not None:
|
||||
token_num_per_seq = get_token_num_per_seq(
|
||||
forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info
|
||||
)
|
||||
|
||||
if (
|
||||
local_batch.forward_mode.is_target_verify()
|
||||
or local_batch.forward_mode.is_decode()
|
||||
):
|
||||
num_tokens = local_batch.batch_size() * token_num_per_seq
|
||||
else:
|
||||
num_tokens = local_batch.extend_num_tokens
|
||||
self.local_tbo_split_seq_index = compute_split_seq_index(
|
||||
forward_mode=local_batch.forward_mode,
|
||||
num_tokens=local_batch.input_ids.shape[0],
|
||||
num_tokens=num_tokens,
|
||||
extend_lens=local_batch.extend_lens,
|
||||
token_num_per_seq=token_num_per_seq,
|
||||
)
|
||||
resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
|
||||
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
|
||||
local_batch.forward_mode.is_extend()
|
||||
(
|
||||
local_batch.forward_mode.is_extend()
|
||||
and not local_batch.forward_mode.is_target_verify()
|
||||
)
|
||||
and enable_deepep_moe
|
||||
and (resolved_deepep_mode == DeepEPMode.low_latency)
|
||||
)
|
||||
@@ -218,8 +356,8 @@ class TboDPAttentionPreparer:
|
||||
|
||||
class TboForwardBatchPreparer:
|
||||
@classmethod
|
||||
def prepare(cls, batch: ForwardBatch):
|
||||
if batch.tbo_split_seq_index is None:
|
||||
def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
|
||||
if batch.tbo_split_seq_index is None or is_draft_worker:
|
||||
return
|
||||
|
||||
tbo_children_num_token_non_padded = (
|
||||
@@ -242,7 +380,9 @@ class TboForwardBatchPreparer:
|
||||
f"TboForwardBatchPreparer.prepare "
|
||||
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
|
||||
f"tbo_split_token_index={tbo_split_token_index} "
|
||||
f"extend_seq_lens={batch.extend_seq_lens_cpu}"
|
||||
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
|
||||
f"bs={batch.batch_size} "
|
||||
f"forward_mode={batch.forward_mode}"
|
||||
)
|
||||
|
||||
assert isinstance(batch.attn_backend, TboAttnBackend)
|
||||
@@ -286,6 +426,9 @@ class TboForwardBatchPreparer:
|
||||
output_attn_backend: AttentionBackend,
|
||||
out_num_token_non_padded: torch.Tensor,
|
||||
):
|
||||
assert (
|
||||
end_token_index >= start_token_index
|
||||
), f"{end_token_index=}, {start_token_index=}, batch={batch}"
|
||||
num_tokens = batch.input_ids.shape[0]
|
||||
num_seqs = batch.batch_size
|
||||
|
||||
@@ -317,11 +460,30 @@ class TboForwardBatchPreparer:
|
||||
old_value = getattr(batch, key)
|
||||
if old_value is None:
|
||||
continue
|
||||
elif batch.forward_mode.is_target_verify() and (
|
||||
key == "extend_seq_lens"
|
||||
or key == "extend_prefix_lens"
|
||||
or key == "extend_start_loc"
|
||||
or key == "extend_prefix_lens_cpu"
|
||||
or key == "extend_seq_lens_cpu"
|
||||
or key == "extend_logprob_start_lens_cpu"
|
||||
):
|
||||
output_dict[key] = None
|
||||
continue
|
||||
assert (
|
||||
len(old_value) == num_seqs
|
||||
), f"{key=} {old_value=} {num_seqs=} {batch=}"
|
||||
output_dict[key] = old_value[start_seq_index:end_seq_index]
|
||||
|
||||
spec_info = getattr(batch, "spec_info")
|
||||
output_spec_info = split_spec_info(
|
||||
spec_info=spec_info,
|
||||
start_token_index=start_token_index,
|
||||
end_token_index=end_token_index,
|
||||
start_seq_index=start_seq_index,
|
||||
end_seq_index=end_seq_index,
|
||||
)
|
||||
output_dict["spec_info"] = output_spec_info
|
||||
for key in [
|
||||
"forward_mode",
|
||||
"return_logprob",
|
||||
@@ -329,18 +491,17 @@ class TboForwardBatchPreparer:
|
||||
"token_to_kv_pool",
|
||||
"can_run_dp_cuda_graph",
|
||||
"global_forward_mode",
|
||||
"spec_info",
|
||||
"spec_algorithm",
|
||||
"capture_hidden_mode",
|
||||
"padded_static_len",
|
||||
"mrope_positions", # only used by qwen2-vl, thus not care
|
||||
]:
|
||||
output_dict[key] = getattr(batch, key)
|
||||
|
||||
assert (
|
||||
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
|
||||
== batch.extend_num_tokens
|
||||
), f"{batch=}"
|
||||
if not batch.forward_mode.is_target_verify():
|
||||
assert (
|
||||
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
|
||||
== batch.extend_num_tokens
|
||||
), f"{batch=}"
|
||||
extend_num_tokens = _compute_extend_num_tokens(
|
||||
output_dict["input_ids"], output_dict["forward_mode"]
|
||||
)
|
||||
@@ -419,18 +580,26 @@ class TboForwardBatchPreparer:
|
||||
|
||||
@classmethod
|
||||
def _compute_split_token_index(cls, batch: ForwardBatch):
|
||||
token_num_per_seq = get_token_num_per_seq(
|
||||
forward_mode=batch.forward_mode, spec_info=batch.spec_info
|
||||
)
|
||||
return compute_split_token_index(
|
||||
split_seq_index=batch.tbo_split_seq_index,
|
||||
forward_mode=batch.forward_mode,
|
||||
extend_seq_lens=batch.extend_seq_lens_cpu,
|
||||
token_num_per_seq=token_num_per_seq,
|
||||
)
|
||||
|
||||
|
||||
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
|
||||
if forward_mode.is_extend():
|
||||
return input_ids.shape[0]
|
||||
elif forward_mode.is_decode() or forward_mode.is_idle():
|
||||
if (
|
||||
forward_mode.is_decode()
|
||||
or forward_mode.is_idle()
|
||||
or forward_mode.is_target_verify()
|
||||
):
|
||||
return None
|
||||
elif forward_mode.is_extend():
|
||||
return input_ids.shape[0]
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user