Speed up two-batch overlap when there are padding tokens (#6668)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -454,6 +454,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
num_expert_group=self.num_expert_group,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
|
||||
@@ -110,7 +110,7 @@ def compute_split_indices_for_cuda_graph_replay(
|
||||
|
||||
class TboCudaGraphRunnerPlugin:
|
||||
def __init__(self):
|
||||
pass # TODO add logic here
|
||||
self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
|
||||
|
||||
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
|
||||
if not global_server_args_dict["enable_two_batch_overlap"]:
|
||||
@@ -124,7 +124,14 @@ class TboCudaGraphRunnerPlugin:
|
||||
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
|
||||
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
|
||||
|
||||
TboForwardBatchPreparer.prepare(batch)
|
||||
self._tbo_children_num_token_non_padded[...] = (
|
||||
TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch)
|
||||
)
|
||||
|
||||
TboForwardBatchPreparer.prepare_raw(
|
||||
batch,
|
||||
tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded,
|
||||
)
|
||||
|
||||
def replay_prepare(
|
||||
self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
|
||||
@@ -132,7 +139,20 @@ class TboCudaGraphRunnerPlugin:
|
||||
if not global_server_args_dict["enable_two_batch_overlap"]:
|
||||
return
|
||||
|
||||
pass # TODO add logic here
|
||||
tbo_split_seq_index, tbo_split_token_index = (
|
||||
compute_split_indices_for_cuda_graph_replay(
|
||||
forward_mode=forward_mode,
|
||||
# TODO support bs!=num_tokens
|
||||
cuda_graph_num_tokens=bs,
|
||||
)
|
||||
)
|
||||
|
||||
self._tbo_children_num_token_non_padded[...] = (
|
||||
TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(
|
||||
tbo_split_token_index=tbo_split_token_index,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TboDPAttentionPreparer:
|
||||
@@ -207,16 +227,23 @@ class TboDPAttentionPreparer:
|
||||
class TboForwardBatchPreparer:
|
||||
@classmethod
|
||||
def prepare(cls, batch: ForwardBatch):
|
||||
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
|
||||
|
||||
if batch.tbo_split_seq_index is None:
|
||||
return
|
||||
|
||||
tbo_split_token_index = compute_split_token_index(
|
||||
split_seq_index=batch.tbo_split_seq_index,
|
||||
forward_mode=batch.forward_mode,
|
||||
extend_seq_lens=batch.extend_seq_lens_cpu,
|
||||
tbo_children_num_token_non_padded = (
|
||||
cls.compute_tbo_children_num_token_non_padded(batch)
|
||||
)
|
||||
cls.prepare_raw(
|
||||
batch, tbo_children_num_token_non_padded=tbo_children_num_token_non_padded
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def prepare_raw(
|
||||
cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor
|
||||
):
|
||||
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
|
||||
|
||||
tbo_split_token_index = cls._compute_split_token_index(batch)
|
||||
|
||||
if _tbo_debug:
|
||||
logger.info(
|
||||
@@ -229,6 +256,10 @@ class TboForwardBatchPreparer:
|
||||
assert isinstance(batch.attn_backend, TboAttnBackend)
|
||||
attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children
|
||||
|
||||
[out_num_token_non_padded_a, out_num_token_non_padded_b] = (
|
||||
tbo_children_num_token_non_padded
|
||||
)
|
||||
|
||||
child_a = cls.filter_batch(
|
||||
batch,
|
||||
start_token_index=0,
|
||||
@@ -236,6 +267,7 @@ class TboForwardBatchPreparer:
|
||||
start_seq_index=0,
|
||||
end_seq_index=batch.tbo_split_seq_index,
|
||||
output_attn_backend=attn_backend_child_a,
|
||||
out_num_token_non_padded=out_num_token_non_padded_a,
|
||||
)
|
||||
child_b = cls.filter_batch(
|
||||
batch,
|
||||
@@ -244,6 +276,7 @@ class TboForwardBatchPreparer:
|
||||
start_seq_index=batch.tbo_split_seq_index,
|
||||
end_seq_index=batch.batch_size,
|
||||
output_attn_backend=attn_backend_child_b,
|
||||
out_num_token_non_padded=out_num_token_non_padded_b,
|
||||
)
|
||||
|
||||
assert batch.tbo_children is None
|
||||
@@ -259,9 +292,8 @@ class TboForwardBatchPreparer:
|
||||
start_seq_index: int,
|
||||
end_seq_index: int,
|
||||
output_attn_backend: AttentionBackend,
|
||||
out_num_token_non_padded: torch.Tensor,
|
||||
):
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
num_tokens = batch.input_ids.shape[0]
|
||||
num_seqs = batch.batch_size
|
||||
|
||||
@@ -342,6 +374,7 @@ class TboForwardBatchPreparer:
|
||||
),
|
||||
extend_num_tokens=extend_num_tokens,
|
||||
attn_backend=output_attn_backend,
|
||||
num_token_non_padded=out_num_token_non_padded,
|
||||
tbo_split_seq_index=None,
|
||||
tbo_parent_token_range=(start_token_index, end_token_index),
|
||||
tbo_children=None,
|
||||
@@ -357,7 +390,6 @@ class TboForwardBatchPreparer:
|
||||
top_p_normalized_logprobs=False,
|
||||
top_p=None,
|
||||
mm_inputs=None,
|
||||
num_token_non_padded=None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -372,6 +404,32 @@ class TboForwardBatchPreparer:
|
||||
|
||||
return ForwardBatch(**output_dict)
|
||||
|
||||
@classmethod
|
||||
def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
|
||||
return cls.compute_tbo_children_num_token_non_padded_raw(
|
||||
tbo_split_token_index=cls._compute_split_token_index(batch),
|
||||
num_token_non_padded=len(batch.input_ids),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def compute_tbo_children_num_token_non_padded_raw(
|
||||
cls, tbo_split_token_index: int, num_token_non_padded: int
|
||||
):
|
||||
# TODO we may make padding on both sub-batches to make it slightly more balanced
|
||||
value_a = min(tbo_split_token_index, num_token_non_padded)
|
||||
value_b = max(0, num_token_non_padded - tbo_split_token_index)
|
||||
return torch.tensor([value_a, value_b], dtype=torch.int32).to(
|
||||
device=global_server_args_dict["device"], non_blocking=True
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _compute_split_token_index(cls, batch: ForwardBatch):
|
||||
return compute_split_token_index(
|
||||
split_seq_index=batch.tbo_split_seq_index,
|
||||
forward_mode=batch.forward_mode,
|
||||
extend_seq_lens=batch.extend_seq_lens_cpu,
|
||||
)
|
||||
|
||||
|
||||
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
|
||||
if forward_mode.is_extend():
|
||||
|
||||
Reference in New Issue
Block a user