feat: support compatibility between MTP and two-batch-overlap (#7225)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend):
|
|||||||
replay_seq_lens_sum: int = None,
|
replay_seq_lens_sum: int = None,
|
||||||
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
|
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
|
||||||
|
forward_mode=forward_mode, spec_info=spec_info
|
||||||
|
)
|
||||||
if fn_name == "init_forward_metadata_capture_cuda_graph":
|
if fn_name == "init_forward_metadata_capture_cuda_graph":
|
||||||
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
|
assert (
|
||||||
num_tokens = bs
|
capture_num_tokens == bs * token_num_per_seq
|
||||||
|
), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
|
||||||
|
num_tokens = bs * token_num_per_seq
|
||||||
|
|
||||||
tbo_split_seq_index, tbo_split_token_index = (
|
tbo_split_seq_index, tbo_split_token_index = (
|
||||||
two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
|
two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
cuda_graph_num_tokens=num_tokens,
|
cuda_graph_num_tokens=num_tokens,
|
||||||
|
spec_info=spec_info,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
num_tokens_child_left = tbo_split_token_index
|
num_tokens_child_left = tbo_split_token_index
|
||||||
num_tokens_child_right = num_tokens - tbo_split_token_index
|
num_tokens_child_right = num_tokens - tbo_split_token_index
|
||||||
bs_child_left = num_tokens_child_left
|
bs_child_left = tbo_split_seq_index
|
||||||
bs_child_right = num_tokens_child_right
|
bs_child_right = bs - bs_child_left
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
num_tokens_child_left > 0 and num_tokens_child_right > 0
|
num_tokens_child_left > 0 and num_tokens_child_right > 0
|
||||||
@@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: "ForwardMode",
|
forward_mode: "ForwardMode",
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
spec_info: Optional[EagleVerifyInput],
|
||||||
# capture args
|
# capture args
|
||||||
capture_num_tokens: int = None,
|
capture_num_tokens: int = None,
|
||||||
# replay args
|
# replay args
|
||||||
replay_seq_lens_sum: int = None,
|
replay_seq_lens_sum: int = None,
|
||||||
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
|
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
|
||||||
|
forward_mode=forward_mode, spec_info=spec_info
|
||||||
|
)
|
||||||
assert encoder_lens is None, "encoder_lens is not supported yet"
|
assert encoder_lens is None, "encoder_lens is not supported yet"
|
||||||
assert spec_info is None, "spec_info is not supported yet"
|
if spec_info is not None:
|
||||||
|
output_spec_info = two_batch_overlap.split_spec_info(
|
||||||
|
spec_info=spec_info,
|
||||||
|
start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
|
||||||
|
end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
|
||||||
|
start_token_index=(
|
||||||
|
seq_slice.start * token_num_per_seq
|
||||||
|
if seq_slice.start is not None
|
||||||
|
else 0
|
||||||
|
),
|
||||||
|
end_token_index=(
|
||||||
|
seq_slice.stop * token_num_per_seq
|
||||||
|
if seq_slice.stop is not None
|
||||||
|
else bs * token_num_per_seq
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
output_spec_info = None
|
||||||
ans = dict(
|
ans = dict(
|
||||||
bs=output_bs,
|
bs=output_bs,
|
||||||
req_pool_indices=req_pool_indices[seq_slice],
|
req_pool_indices=req_pool_indices[seq_slice],
|
||||||
@@ -208,14 +234,16 @@ def _init_forward_metadata_cuda_graph_split(
|
|||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
# ignore
|
# ignore
|
||||||
encoder_lens=None,
|
encoder_lens=None,
|
||||||
spec_info=None,
|
spec_info=output_spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
if fn_name == "init_forward_metadata_capture_cuda_graph":
|
if fn_name == "init_forward_metadata_capture_cuda_graph":
|
||||||
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
|
assert (
|
||||||
|
capture_num_tokens == bs * token_num_per_seq
|
||||||
|
), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
|
||||||
ans.update(
|
ans.update(
|
||||||
dict(
|
dict(
|
||||||
num_tokens=output_bs,
|
num_tokens=output_bs * token_num_per_seq,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif fn_name == "init_forward_metadata_replay_cuda_graph":
|
elif fn_name == "init_forward_metadata_replay_cuda_graph":
|
||||||
|
|||||||
@@ -679,6 +679,7 @@ class CudaGraphRunner:
|
|||||||
forward_mode=self.capture_forward_mode,
|
forward_mode=self.capture_forward_mode,
|
||||||
bs=bs,
|
bs=bs,
|
||||||
num_token_non_padded=len(forward_batch.input_ids),
|
num_token_non_padded=len(forward_batch.input_ids),
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
)
|
)
|
||||||
if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
|
if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
|
||||||
forward_batch.spec_info.custom_mask = self.custom_mask
|
forward_batch.spec_info.custom_mask = self.custom_mask
|
||||||
|
|||||||
@@ -352,7 +352,9 @@ class ForwardBatch:
|
|||||||
|
|
||||||
if ret.forward_mode.is_idle():
|
if ret.forward_mode.is_idle():
|
||||||
ret.positions = torch.empty((0,), device=device)
|
ret.positions = torch.empty((0,), device=device)
|
||||||
TboForwardBatchPreparer.prepare(ret)
|
TboForwardBatchPreparer.prepare(
|
||||||
|
ret, is_draft_worker=model_runner.is_draft_worker
|
||||||
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
# Override the positions with spec_info
|
# Override the positions with spec_info
|
||||||
@@ -397,7 +399,9 @@ class ForwardBatch:
|
|||||||
if model_runner.server_args.lora_paths is not None:
|
if model_runner.server_args.lora_paths is not None:
|
||||||
model_runner.lora_manager.prepare_lora_batch(ret)
|
model_runner.lora_manager.prepare_lora_batch(ret)
|
||||||
|
|
||||||
TboForwardBatchPreparer.prepare(ret)
|
TboForwardBatchPreparer.prepare(
|
||||||
|
ret, is_draft_worker=model_runner.is_draft_worker
|
||||||
|
)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|||||||
@@ -1039,7 +1039,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
def init_attention_backend(self):
|
def init_attention_backend(self):
|
||||||
"""Init attention kernel backend."""
|
"""Init attention kernel backend."""
|
||||||
if self.server_args.enable_two_batch_overlap:
|
if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
|
||||||
self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
|
self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
|
||||||
else:
|
else:
|
||||||
self.attn_backend = self._get_attention_backend()
|
self.attn_backend = self._get_attention_backend()
|
||||||
|
|||||||
@@ -71,7 +71,9 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
|
|||||||
assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
|
assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
|
||||||
if forward_mode == ForwardMode.EXTEND:
|
if forward_mode == ForwardMode.EXTEND:
|
||||||
return _compute_moe_deepseek_blog_prefill(layer)
|
return _compute_moe_deepseek_blog_prefill(layer)
|
||||||
elif forward_mode == ForwardMode.DECODE:
|
elif (
|
||||||
|
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
|
||||||
|
):
|
||||||
return _compute_moe_deepseek_blog_decode(layer)
|
return _compute_moe_deepseek_blog_decode(layer)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
||||||
@@ -146,7 +148,9 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
|
|||||||
assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
|
assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
|
||||||
if forward_mode == ForwardMode.EXTEND:
|
if forward_mode == ForwardMode.EXTEND:
|
||||||
return _compute_moe_qwen3_prefill(layer)
|
return _compute_moe_qwen3_prefill(layer)
|
||||||
elif forward_mode == ForwardMode.DECODE:
|
elif (
|
||||||
|
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
|
||||||
|
):
|
||||||
return _compute_moe_qwen3_decode(layer)
|
return _compute_moe_qwen3_decode(layer)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Sequence
|
from dataclasses import replace
|
||||||
|
from typing import Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -16,6 +17,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||||
from sglang.srt.operations_strategy import OperationsStrategy
|
from sglang.srt.operations_strategy import OperationsStrategy
|
||||||
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
|
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
|
||||||
|
|
||||||
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
||||||
@@ -26,17 +28,34 @@ logger = logging.getLogger(__name__)
|
|||||||
# -------------------------------- Compute Basic Info ---------------------------------------
|
# -------------------------------- Compute Basic Info ---------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_num_per_seq(
    forward_mode: "ForwardMode",
    spec_info: Optional[Union["EagleDraftInput", "EagleVerifyInput"]] = None,
):
    """Return how many tokens each sequence contributes in ``forward_mode``.

    Args:
        forward_mode: The batch's forward mode; only its ``is_target_verify``,
            ``is_decode`` and ``is_idle`` predicates are consulted.
        spec_info: Speculative-decoding input. Required in target-verify mode,
            where its ``draft_token_num`` is the per-sequence token count.

    Returns:
        ``spec_info.draft_token_num`` for target-verify, ``1`` for decode,
        ``0`` for idle, and ``None`` for any other mode (e.g. extend), where a
        fixed per-sequence token count is not meaningful.

    Raises:
        ValueError: If the mode is target-verify but ``spec_info`` is ``None``.
    """
    if forward_mode.is_target_verify():
        # `spec_info` defaults to None, so fail with a clear message instead of
        # an AttributeError on `None.draft_token_num`.
        if spec_info is None:
            raise ValueError(
                "spec_info is required to compute token_num_per_seq "
                "in target-verify mode"
            )
        return spec_info.draft_token_num
    if forward_mode.is_decode():
        return 1
    if forward_mode.is_idle():
        return 0
    # For extend, we should not use `token_num_per_seq`.
    return None
|
||||||
|
|
||||||
|
|
||||||
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
|
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
|
||||||
def compute_split_seq_index(
|
def compute_split_seq_index(
|
||||||
forward_mode: "ForwardMode",
|
forward_mode: "ForwardMode",
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
extend_lens: Optional[Sequence[int]],
|
extend_lens: Optional[Sequence[int]],
|
||||||
|
token_num_per_seq: Optional[int],
|
||||||
) -> Optional[int]:
|
) -> Optional[int]:
|
||||||
if forward_mode.is_extend():
|
if forward_mode == ForwardMode.EXTEND:
|
||||||
assert extend_lens is not None
|
assert extend_lens is not None
|
||||||
return _split_array_by_half_sum(extend_lens)
|
return _split_array_by_half_sum(extend_lens)
|
||||||
elif forward_mode.is_decode():
|
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
||||||
return num_tokens // 2
|
assert token_num_per_seq is not None
|
||||||
|
return (num_tokens // token_num_per_seq) // 2
|
||||||
elif forward_mode.is_idle():
|
elif forward_mode.is_idle():
|
||||||
assert num_tokens == 0
|
assert num_tokens == 0
|
||||||
return 0
|
return 0
|
||||||
@@ -63,16 +82,103 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
|||||||
return best_index
|
return best_index
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
|
||||||
|
if seq_index == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])
|
||||||
|
for i in range(max_seq_len):
|
||||||
|
offset += (
|
||||||
|
spec_info.seq_lens_cpu[i] + spec_info.draft_token_num
|
||||||
|
) * spec_info.draft_token_num
|
||||||
|
return offset
|
||||||
|
|
||||||
|
|
||||||
|
def split_spec_info(
    spec_info: Optional[EagleVerifyInput],
    start_seq_index: int,
    end_seq_index: int,
    start_token_index: int,
    end_token_index: int,
):
    """Build the spec info for one child batch covering a sub-range of sequences.

    Token-indexed fields (``draft_token``, ``positions``) are narrowed to
    ``[start_token_index, end_token_index)``; sequence-indexed fields
    (``retrive_*``, ``seq_lens_cpu``) are narrowed to
    ``[start_seq_index, end_seq_index)``. Fields that are ``None`` stay ``None``.

    Returns:
        ``None`` when ``spec_info`` is ``None``; otherwise a copy of
        ``spec_info`` (via ``dataclasses.replace``) with the narrowed fields.
    """
    if spec_info is None:
        return None

    def _narrow(value, lo, hi):
        # Absent optional fields pass through untouched.
        return None if value is None else value[lo:hi]

    # The mask is only split when draft tokens are present; otherwise, or when
    # the computed range is empty/inverted, the whole mask is handed through.
    custom_mask = spec_info.custom_mask
    if custom_mask is not None and spec_info.draft_token is not None:
        mask_lo = _compute_mask_offset(start_seq_index, spec_info)
        if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
            # The range runs to the last sequence: take the mask to its end.
            mask_hi = custom_mask.shape[0]
        else:
            mask_hi = _compute_mask_offset(end_seq_index, spec_info)
        if mask_hi > mask_lo:
            custom_mask = custom_mask[mask_lo:mask_hi]

    seq_lens_cpu = _narrow(spec_info.seq_lens_cpu, start_seq_index, end_seq_index)
    seq_lens_sum = None if seq_lens_cpu is None else seq_lens_cpu.sum()

    return replace(
        spec_info,
        custom_mask=custom_mask,
        draft_token=_narrow(spec_info.draft_token, start_token_index, end_token_index),
        positions=_narrow(spec_info.positions, start_token_index, end_token_index),
        retrive_index=_narrow(spec_info.retrive_index, start_seq_index, end_seq_index),
        retrive_next_token=_narrow(
            spec_info.retrive_next_token, start_seq_index, end_seq_index
        ),
        retrive_next_sibling=_narrow(
            spec_info.retrive_next_sibling, start_seq_index, end_seq_index
        ),
        retrive_cum_len=_narrow(
            spec_info.retrive_cum_len, start_seq_index, end_seq_index
        ),
        seq_lens_cpu=seq_lens_cpu,
        seq_lens_sum=seq_lens_sum,
    )
|
||||||
|
|
||||||
|
|
||||||
def compute_split_token_index(
|
def compute_split_token_index(
|
||||||
split_seq_index: int,
|
split_seq_index: int,
|
||||||
forward_mode: "ForwardMode",
|
forward_mode: "ForwardMode",
|
||||||
extend_seq_lens: Optional[Sequence[int]],
|
extend_seq_lens: Optional[Sequence[int]],
|
||||||
|
token_num_per_seq: Optional[int],
|
||||||
) -> int:
|
) -> int:
|
||||||
if forward_mode.is_extend():
|
if forward_mode == ForwardMode.EXTEND:
|
||||||
assert extend_seq_lens is not None
|
assert extend_seq_lens is not None
|
||||||
return sum(extend_seq_lens[:split_seq_index])
|
return sum(extend_seq_lens[:split_seq_index])
|
||||||
elif forward_mode.is_decode():
|
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
||||||
return split_seq_index
|
assert token_num_per_seq is not None
|
||||||
|
return split_seq_index * token_num_per_seq
|
||||||
elif forward_mode.is_idle():
|
elif forward_mode.is_idle():
|
||||||
assert split_seq_index == 0
|
assert split_seq_index == 0
|
||||||
return 0
|
return 0
|
||||||
@@ -83,19 +189,25 @@ def compute_split_token_index(
|
|||||||
def compute_split_indices_for_cuda_graph_replay(
    forward_mode: ForwardMode,
    cuda_graph_num_tokens: int,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
    """Compute the TBO split point for a CUDA-graph batch.

    Args:
        forward_mode: Mode of the batch. ``ForwardMode.IDLE`` is treated as
            ``ForwardMode.DECODE`` for the purpose of choosing the split.
        cuda_graph_num_tokens: Number of tokens in the CUDA-graph batch.
        spec_info: Speculative-decoding input, forwarded to
            ``get_token_num_per_seq``.

    Returns:
        A ``(tbo_split_seq_index, tbo_split_token_index)`` tuple.
    """
    forward_mode_for_tbo_split = (
        forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
    )
    # Use the substituted mode here as well: the raw IDLE mode yields a
    # per-sequence token count of 0, which the decode branch of
    # `compute_split_seq_index` would then divide by.
    token_num_per_seq = get_token_num_per_seq(
        forward_mode=forward_mode_for_tbo_split, spec_info=spec_info
    )
    tbo_split_seq_index = compute_split_seq_index(
        forward_mode=forward_mode_for_tbo_split,
        num_tokens=cuda_graph_num_tokens,
        extend_lens=None,
        token_num_per_seq=token_num_per_seq,
    )
    tbo_split_token_index = compute_split_token_index(
        split_seq_index=tbo_split_seq_index,
        forward_mode=forward_mode_for_tbo_split,
        extend_seq_lens=None,
        token_num_per_seq=token_num_per_seq,
    )
    return tbo_split_seq_index, tbo_split_token_index
|
||||||
|
|
||||||
@@ -110,11 +222,15 @@ class TboCudaGraphRunnerPlugin:
|
|||||||
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
|
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
|
||||||
if not global_server_args_dict["enable_two_batch_overlap"]:
|
if not global_server_args_dict["enable_two_batch_overlap"]:
|
||||||
return
|
return
|
||||||
|
token_num_per_seq = get_token_num_per_seq(
|
||||||
|
forward_mode=batch.forward_mode, spec_info=batch.spec_info
|
||||||
|
)
|
||||||
|
|
||||||
batch.tbo_split_seq_index = compute_split_seq_index(
|
batch.tbo_split_seq_index = compute_split_seq_index(
|
||||||
forward_mode=batch.forward_mode,
|
forward_mode=batch.forward_mode,
|
||||||
num_tokens=num_tokens,
|
num_tokens=num_tokens,
|
||||||
extend_lens=None,
|
extend_lens=None,
|
||||||
|
token_num_per_seq=token_num_per_seq,
|
||||||
)
|
)
|
||||||
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
|
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
|
||||||
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
|
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
|
||||||
@@ -129,13 +245,20 @@ class TboCudaGraphRunnerPlugin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def replay_prepare(
|
def replay_prepare(
|
||||||
self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
|
self,
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
bs: int,
|
||||||
|
num_token_non_padded: int,
|
||||||
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
|
token_num_per_seq = get_token_num_per_seq(
|
||||||
|
forward_mode=forward_mode, spec_info=spec_info
|
||||||
|
)
|
||||||
tbo_split_seq_index, tbo_split_token_index = (
|
tbo_split_seq_index, tbo_split_token_index = (
|
||||||
compute_split_indices_for_cuda_graph_replay(
|
compute_split_indices_for_cuda_graph_replay(
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
# TODO support bs!=num_tokens
|
cuda_graph_num_tokens=bs * token_num_per_seq,
|
||||||
cuda_graph_num_tokens=bs,
|
spec_info=spec_info,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -154,14 +277,29 @@ class TboDPAttentionPreparer:
|
|||||||
self.enable_two_batch_overlap = enable_two_batch_overlap
|
self.enable_two_batch_overlap = enable_two_batch_overlap
|
||||||
|
|
||||||
if local_batch is not None:
|
if local_batch is not None:
|
||||||
|
token_num_per_seq = get_token_num_per_seq(
|
||||||
|
forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
local_batch.forward_mode.is_target_verify()
|
||||||
|
or local_batch.forward_mode.is_decode()
|
||||||
|
):
|
||||||
|
num_tokens = local_batch.batch_size() * token_num_per_seq
|
||||||
|
else:
|
||||||
|
num_tokens = local_batch.extend_num_tokens
|
||||||
self.local_tbo_split_seq_index = compute_split_seq_index(
|
self.local_tbo_split_seq_index = compute_split_seq_index(
|
||||||
forward_mode=local_batch.forward_mode,
|
forward_mode=local_batch.forward_mode,
|
||||||
num_tokens=local_batch.input_ids.shape[0],
|
num_tokens=num_tokens,
|
||||||
extend_lens=local_batch.extend_lens,
|
extend_lens=local_batch.extend_lens,
|
||||||
|
token_num_per_seq=token_num_per_seq,
|
||||||
)
|
)
|
||||||
resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
|
resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
|
||||||
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
|
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
|
||||||
local_batch.forward_mode.is_extend()
|
(
|
||||||
|
local_batch.forward_mode.is_extend()
|
||||||
|
and not local_batch.forward_mode.is_target_verify()
|
||||||
|
)
|
||||||
and enable_deepep_moe
|
and enable_deepep_moe
|
||||||
and (resolved_deepep_mode == DeepEPMode.low_latency)
|
and (resolved_deepep_mode == DeepEPMode.low_latency)
|
||||||
)
|
)
|
||||||
@@ -218,8 +356,8 @@ class TboDPAttentionPreparer:
|
|||||||
|
|
||||||
class TboForwardBatchPreparer:
|
class TboForwardBatchPreparer:
|
||||||
@classmethod
|
@classmethod
|
||||||
def prepare(cls, batch: ForwardBatch):
|
def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
|
||||||
if batch.tbo_split_seq_index is None:
|
if batch.tbo_split_seq_index is None or is_draft_worker:
|
||||||
return
|
return
|
||||||
|
|
||||||
tbo_children_num_token_non_padded = (
|
tbo_children_num_token_non_padded = (
|
||||||
@@ -242,7 +380,9 @@ class TboForwardBatchPreparer:
|
|||||||
f"TboForwardBatchPreparer.prepare "
|
f"TboForwardBatchPreparer.prepare "
|
||||||
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
|
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
|
||||||
f"tbo_split_token_index={tbo_split_token_index} "
|
f"tbo_split_token_index={tbo_split_token_index} "
|
||||||
f"extend_seq_lens={batch.extend_seq_lens_cpu}"
|
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
|
||||||
|
f"bs={batch.batch_size} "
|
||||||
|
f"forward_mode={batch.forward_mode}"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(batch.attn_backend, TboAttnBackend)
|
assert isinstance(batch.attn_backend, TboAttnBackend)
|
||||||
@@ -286,6 +426,9 @@ class TboForwardBatchPreparer:
|
|||||||
output_attn_backend: AttentionBackend,
|
output_attn_backend: AttentionBackend,
|
||||||
out_num_token_non_padded: torch.Tensor,
|
out_num_token_non_padded: torch.Tensor,
|
||||||
):
|
):
|
||||||
|
assert (
|
||||||
|
end_token_index >= start_token_index
|
||||||
|
), f"{end_token_index=}, {start_token_index=}, batch={batch}"
|
||||||
num_tokens = batch.input_ids.shape[0]
|
num_tokens = batch.input_ids.shape[0]
|
||||||
num_seqs = batch.batch_size
|
num_seqs = batch.batch_size
|
||||||
|
|
||||||
@@ -317,11 +460,30 @@ class TboForwardBatchPreparer:
|
|||||||
old_value = getattr(batch, key)
|
old_value = getattr(batch, key)
|
||||||
if old_value is None:
|
if old_value is None:
|
||||||
continue
|
continue
|
||||||
|
elif batch.forward_mode.is_target_verify() and (
|
||||||
|
key == "extend_seq_lens"
|
||||||
|
or key == "extend_prefix_lens"
|
||||||
|
or key == "extend_start_loc"
|
||||||
|
or key == "extend_prefix_lens_cpu"
|
||||||
|
or key == "extend_seq_lens_cpu"
|
||||||
|
or key == "extend_logprob_start_lens_cpu"
|
||||||
|
):
|
||||||
|
output_dict[key] = None
|
||||||
|
continue
|
||||||
assert (
|
assert (
|
||||||
len(old_value) == num_seqs
|
len(old_value) == num_seqs
|
||||||
), f"{key=} {old_value=} {num_seqs=} {batch=}"
|
), f"{key=} {old_value=} {num_seqs=} {batch=}"
|
||||||
output_dict[key] = old_value[start_seq_index:end_seq_index]
|
output_dict[key] = old_value[start_seq_index:end_seq_index]
|
||||||
|
|
||||||
|
spec_info = getattr(batch, "spec_info")
|
||||||
|
output_spec_info = split_spec_info(
|
||||||
|
spec_info=spec_info,
|
||||||
|
start_token_index=start_token_index,
|
||||||
|
end_token_index=end_token_index,
|
||||||
|
start_seq_index=start_seq_index,
|
||||||
|
end_seq_index=end_seq_index,
|
||||||
|
)
|
||||||
|
output_dict["spec_info"] = output_spec_info
|
||||||
for key in [
|
for key in [
|
||||||
"forward_mode",
|
"forward_mode",
|
||||||
"return_logprob",
|
"return_logprob",
|
||||||
@@ -329,18 +491,17 @@ class TboForwardBatchPreparer:
|
|||||||
"token_to_kv_pool",
|
"token_to_kv_pool",
|
||||||
"can_run_dp_cuda_graph",
|
"can_run_dp_cuda_graph",
|
||||||
"global_forward_mode",
|
"global_forward_mode",
|
||||||
"spec_info",
|
|
||||||
"spec_algorithm",
|
"spec_algorithm",
|
||||||
"capture_hidden_mode",
|
"capture_hidden_mode",
|
||||||
"padded_static_len",
|
"padded_static_len",
|
||||||
"mrope_positions", # only used by qwen2-vl, thus not care
|
"mrope_positions", # only used by qwen2-vl, thus not care
|
||||||
]:
|
]:
|
||||||
output_dict[key] = getattr(batch, key)
|
output_dict[key] = getattr(batch, key)
|
||||||
|
if not batch.forward_mode.is_target_verify():
|
||||||
assert (
|
assert (
|
||||||
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
|
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
|
||||||
== batch.extend_num_tokens
|
== batch.extend_num_tokens
|
||||||
), f"{batch=}"
|
), f"{batch=}"
|
||||||
extend_num_tokens = _compute_extend_num_tokens(
|
extend_num_tokens = _compute_extend_num_tokens(
|
||||||
output_dict["input_ids"], output_dict["forward_mode"]
|
output_dict["input_ids"], output_dict["forward_mode"]
|
||||||
)
|
)
|
||||||
@@ -419,18 +580,26 @@ class TboForwardBatchPreparer:
|
|||||||
|
|
||||||
@classmethod
def _compute_split_token_index(cls, batch: ForwardBatch):
    """Return the token index at which ``batch`` is split for two-batch overlap.

    Reads ``forward_mode``, ``spec_info``, ``tbo_split_seq_index`` and
    ``extend_seq_lens_cpu`` from ``batch`` and delegates to the module-level
    ``compute_split_token_index``.
    """
    mode = batch.forward_mode
    per_seq_tokens = get_token_num_per_seq(
        forward_mode=mode, spec_info=batch.spec_info
    )
    return compute_split_token_index(
        split_seq_index=batch.tbo_split_seq_index,
        forward_mode=mode,
        extend_seq_lens=batch.extend_seq_lens_cpu,
        token_num_per_seq=per_seq_tokens,
    )
|
||||||
|
|
||||||
|
|
||||||
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
|
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
|
||||||
if forward_mode.is_extend():
|
if (
|
||||||
return input_ids.shape[0]
|
forward_mode.is_decode()
|
||||||
elif forward_mode.is_decode() or forward_mode.is_idle():
|
or forward_mode.is_idle()
|
||||||
|
or forward_mode.is_target_verify()
|
||||||
|
):
|
||||||
return None
|
return None
|
||||||
|
elif forward_mode.is_extend():
|
||||||
|
return input_ids.shape[0]
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -137,5 +137,86 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
|
|||||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: enable this test later
|
||||||
|
# class TestDPAttentionDP2TP2DeepseekV3MTPTBO(CustomTestCase):
|
||||||
|
# @classmethod
|
||||||
|
# def setUpClass(cls):
|
||||||
|
# import os
|
||||||
|
|
||||||
|
# # print debug log for tbo
|
||||||
|
# os.environ["SGLANG_TBO_DEBUG"] = "1"
|
||||||
|
# cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
|
||||||
|
# cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
# other_args = [
|
||||||
|
# "--trust-remote-code",
|
||||||
|
# "--disable-radix",
|
||||||
|
# "--speculative-algorithm",
|
||||||
|
# "EAGLE",
|
||||||
|
# "--speculative-num-steps",
|
||||||
|
# "2",
|
||||||
|
# "--speculative-eagle-topk",
|
||||||
|
# "4",
|
||||||
|
# "--speculative-num-draft-tokens",
|
||||||
|
# "4",
|
||||||
|
# "--speculative-draft",
|
||||||
|
# DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
|
||||||
|
# "--tp-size",
|
||||||
|
# "2",
|
||||||
|
# "--enable-dp-attention",
|
||||||
|
# "--dp-size",
|
||||||
|
# "2",
|
||||||
|
# "--enable-two-batch-overlap",
|
||||||
|
# "--enable-deepep-moe",
|
||||||
|
# "--deepep-mode",
|
||||||
|
# "low_latency",
|
||||||
|
# "--chunked-prefill-size",
|
||||||
|
# "256",
|
||||||
|
# "--cuda-graph-max-bs",
|
||||||
|
# "32",
|
||||||
|
# "--max-running-requests",
|
||||||
|
# "32",
|
||||||
|
# ]
|
||||||
|
# if not is_in_amd_ci():
|
||||||
|
# other_args += ["--mem-frac", "0.7"]
|
||||||
|
# cls.process = popen_launch_server(
|
||||||
|
# cls.model,
|
||||||
|
# cls.base_url,
|
||||||
|
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
# other_args=other_args,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# @classmethod
|
||||||
|
# def tearDownClass(cls):
|
||||||
|
# kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
# def test_gsm8k(self):
|
||||||
|
# requests.get(self.base_url + "/flush_cache")
|
||||||
|
|
||||||
|
# args = SimpleNamespace(
|
||||||
|
# num_shots=5,
|
||||||
|
# data_path=None,
|
||||||
|
# num_questions=200,
|
||||||
|
# max_new_tokens=512,
|
||||||
|
# parallel=128,
|
||||||
|
# host="http://127.0.0.1",
|
||||||
|
# port=int(self.base_url.split(":")[-1]),
|
||||||
|
# )
|
||||||
|
# metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
# print(metrics)
|
||||||
|
|
||||||
|
# self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
|
# server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
|
# avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||||
|
# "avg_spec_accept_length"
|
||||||
|
# ]
|
||||||
|
# print(
|
||||||
|
# f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n"
|
||||||
|
# f"accuracy={metrics['accuracy']=:.3f}\n"
|
||||||
|
# f"{avg_spec_accept_length=:.3f}\n"
|
||||||
|
# )
|
||||||
|
# self.assertGreater(avg_spec_accept_length, 2.3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user