Support overlapping two batches (#4068)
This commit is contained in:
241
python/sglang/srt/layers/attention/tbo_backend.py
Normal file
241
python/sglang/srt/layers/attention/tbo_backend.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt import two_batch_overlap
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
|
||||
|
||||
class TboAttnBackend(AttentionBackend):
    """Attention backend wrapper for two-batch overlap (TBO).

    Wraps a ``primary`` backend (used for whole-batch operations and actual
    forward calls) plus two ``children`` backends, one per TBO sub-batch.
    Metadata-initialization calls are forwarded to the primary backend and,
    when the batch is split, to each child with its half of the batch.
    """

    def __init__(self, primary: AttentionBackend, children: List[AttentionBackend]):
        """Store the primary backend and the per-sub-batch child backends.

        Args:
            primary: Backend that handles the un-split batch and forwards.
            children: Exactly two backends, one per TBO sub-batch.
        """
        super().__init__()
        self.primary = primary
        self.children = children

    @classmethod
    def init_new(cls, creator: Callable[[], AttentionBackend]) -> "TboAttnBackend":
        """Build a TBO backend from a factory: one primary + two children."""
        return cls(
            primary=creator(),
            children=[creator() for _ in range(2)],
        )

    def init_forward_metadata(self, forward_batch: "ForwardBatch"):
        """Init metadata on the primary backend and on each non-empty child batch."""
        self.primary.init_forward_metadata(forward_batch=forward_batch)
        if forward_batch.tbo_children is not None:
            for child, forward_batch_child in zip(
                self.children, forward_batch.tbo_children, strict=True
            ):
                # Empty sub-batches carry no work; skip to avoid backend errors.
                if forward_batch_child.batch_size > 0:
                    child.init_forward_metadata(forward_batch=forward_batch_child)

    def init_cuda_graph_state(self, max_bs: int):
        """Allocate CUDA-graph state on the primary and both children."""
        self.primary.init_cuda_graph_state(max_bs=max_bs)
        for item in self.children:
            # TODO for children, maybe can provide *smaller* max_bs to optimize
            item.init_cuda_graph_state(max_bs=max_bs)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Capture-phase CUDA-graph metadata init for primary and children."""
        self.primary.init_forward_metadata_capture_cuda_graph(
            bs=bs,
            num_tokens=num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
        )

        self._init_forward_metadata_cuda_graph_children(
            fn_name="init_forward_metadata_capture_cuda_graph",
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            capture_num_tokens=num_tokens,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Replay-phase CUDA-graph metadata init for primary and children."""
        self.primary.init_forward_metadata_replay_cuda_graph(
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_sum=seq_lens_sum,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            seq_lens_cpu=seq_lens_cpu,
        )

        self._init_forward_metadata_cuda_graph_children(
            fn_name="init_forward_metadata_replay_cuda_graph",
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            replay_seq_lens_sum=seq_lens_sum,
            replay_seq_lens_cpu=seq_lens_cpu,
        )

    def _init_forward_metadata_cuda_graph_children(
        self,
        fn_name: str,
        # common args
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        # capture args
        capture_num_tokens: Optional[int] = None,
        # replay args
        replay_seq_lens_sum: Optional[int] = None,
        replay_seq_lens_cpu: Optional[torch.Tensor] = None,
    ):
        """Split the whole-batch metadata args and dispatch ``fn_name`` to both children.

        Computes the TBO split point, slices the batch into a left and a right
        sub-batch, and calls ``getattr(child, fn_name)`` on each child with the
        corresponding slice of the arguments.
        """
        # Local import to avoid a circular import at module load time.
        from sglang.srt.model_executor.forward_batch_info import ForwardMode

        if fn_name == "init_forward_metadata_capture_cuda_graph":
            assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
        # Both capture and replay currently assume one token per sequence.
        num_tokens = bs

        # IDLE batches are split as if decoding, so both children stay non-empty.
        forward_mode_for_tbo_split = (
            forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
        )
        tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
            forward_mode=forward_mode_for_tbo_split,
            num_tokens=num_tokens,
            extend_lens=None,
        )
        tbo_split_token_index = two_batch_overlap.compute_split_token_index(
            split_seq_index=tbo_split_seq_index,
            forward_mode=forward_mode_for_tbo_split,
            extend_seq_lens=None,
        )

        num_tokens_child_left = tbo_split_token_index
        num_tokens_child_right = num_tokens - tbo_split_token_index
        # With one token per sequence, sub-batch size equals its token count.
        bs_child_left = num_tokens_child_left
        bs_child_right = num_tokens_child_right

        assert (
            num_tokens_child_left > 0 and num_tokens_child_right > 0
        ), f"{num_tokens_child_left=} {num_tokens_child_right=} {forward_mode=} {num_tokens=}"

        common_pre_split_args = dict(
            fn_name=fn_name,
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            capture_num_tokens=capture_num_tokens,
            replay_seq_lens_sum=replay_seq_lens_sum,
            replay_seq_lens_cpu=replay_seq_lens_cpu,
        )

        args_left = _init_forward_metadata_cuda_graph_split(
            output_bs=bs_child_left,
            seq_slice=slice(None, tbo_split_seq_index),
            **common_pre_split_args,
        )
        args_right = _init_forward_metadata_cuda_graph_split(
            output_bs=bs_child_right,
            seq_slice=slice(tbo_split_seq_index, None),
            **common_pre_split_args,
        )

        child_left, child_right = self.children
        getattr(child_left, fn_name)(**args_left)
        getattr(child_right, fn_name)(**args_right)

    def get_cuda_graph_seq_len_fill_value(self):
        """Return the fill value, asserting all backends agree on it."""
        ans = self.primary.get_cuda_graph_seq_len_fill_value()
        for child in self.children:
            assert ans == child.get_cuda_graph_seq_len_fill_value()
        return ans

    def forward_extend(self, *args, **kwargs):
        # Forwards always run on the primary backend; children only hold metadata.
        return self.primary.forward_extend(*args, **kwargs)

    def forward_decode(self, *args, **kwargs):
        # Forwards always run on the primary backend; children only hold metadata.
        return self.primary.forward_decode(*args, **kwargs)
|
||||
|
||||
|
||||
def _init_forward_metadata_cuda_graph_split(
|
||||
fn_name: str,
|
||||
seq_slice: slice,
|
||||
output_bs: int,
|
||||
# common args
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: "ForwardMode",
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
# capture args
|
||||
capture_num_tokens: int = None,
|
||||
# replay args
|
||||
replay_seq_lens_sum: int = None,
|
||||
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
|
||||
):
|
||||
assert encoder_lens is None, "encoder_lens is not supported yet"
|
||||
assert spec_info is None, "spec_info is not supported yet"
|
||||
|
||||
ans = dict(
|
||||
bs=output_bs,
|
||||
req_pool_indices=req_pool_indices[seq_slice],
|
||||
seq_lens=seq_lens[seq_slice],
|
||||
# directly forward
|
||||
forward_mode=forward_mode,
|
||||
# ignore
|
||||
encoder_lens=None,
|
||||
spec_info=None,
|
||||
)
|
||||
|
||||
if fn_name == "init_forward_metadata_capture_cuda_graph":
|
||||
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
|
||||
ans.update(
|
||||
dict(
|
||||
num_tokens=output_bs,
|
||||
)
|
||||
)
|
||||
elif fn_name == "init_forward_metadata_replay_cuda_graph":
|
||||
output_seq_lens_cpu = replay_seq_lens_cpu[seq_slice]
|
||||
ans.update(
|
||||
dict(
|
||||
seq_lens_sum=output_seq_lens_cpu.sum().item(),
|
||||
seq_lens_cpu=output_seq_lens_cpu,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return ans
|
||||
@@ -391,3 +391,16 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
|
||||
RuntimeCache.get = __patched_func
|
||||
yield
|
||||
RuntimeCache.get = origin_func
|
||||
|
||||
|
||||
@contextmanager
|
||||
def configure_deep_gemm_num_sms(num_sms):
|
||||
if num_sms is None:
|
||||
yield
|
||||
else:
|
||||
original_num_sms = deep_gemm.get_num_sms()
|
||||
deep_gemm.set_num_sms(num_sms)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
deep_gemm.set_num_sms(original_num_sms)
|
||||
|
||||
@@ -78,6 +78,7 @@ global_server_args_dict = {
|
||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
|
||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||
"enable_two_batch_overlap": ServerArgs.enable_two_batch_overlap,
|
||||
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
|
||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||
"deepep_config": ServerArgs.deepep_config,
|
||||
@@ -831,6 +832,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
global_num_tokens: Optional[List[int]] = None
|
||||
global_num_tokens_for_logprob: Optional[List[int]] = None
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
tbo_split_seq_index: Optional[int] = None
|
||||
global_forward_mode: Optional[ForwardMode] = None
|
||||
|
||||
# For processing logprobs
|
||||
return_logprob: bool = False
|
||||
@@ -1624,6 +1627,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
or global_server_args_dict["attention_backend"] == "flashmla"
|
||||
or global_server_args_dict["attention_backend"] == "fa3"
|
||||
or global_server_args_dict["attention_backend"] == "cutlass_mla"
|
||||
or global_server_args_dict["enable_two_batch_overlap"]
|
||||
):
|
||||
seq_lens_cpu = self.seq_lens.cpu()
|
||||
else:
|
||||
@@ -1651,6 +1655,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
tbo_split_seq_index=self.tbo_split_seq_index,
|
||||
global_forward_mode=self.global_forward_mode,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
extend_num_tokens=self.extend_num_tokens,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
@@ -1729,6 +1735,8 @@ class ModelWorkerBatch:
|
||||
global_num_tokens: Optional[List[int]]
|
||||
global_num_tokens_for_logprob: Optional[List[int]]
|
||||
can_run_dp_cuda_graph: bool
|
||||
tbo_split_seq_index: Optional[int]
|
||||
global_forward_mode: Optional[ForwardMode]
|
||||
|
||||
# For extend
|
||||
extend_num_tokens: Optional[int]
|
||||
|
||||
@@ -34,6 +34,7 @@ import zmq
|
||||
from torch.distributed import barrier
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt import two_batch_overlap
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
|
||||
from sglang.srt.disaggregation.decode import (
|
||||
@@ -132,7 +133,9 @@ from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
|
||||
from sglang.srt.utils import (
|
||||
DeepEPMode,
|
||||
DynamicGradMode,
|
||||
broadcast_pyobj,
|
||||
configure_logger,
|
||||
@@ -1648,6 +1651,9 @@ class Scheduler(
|
||||
disable_cuda_graph=self.server_args.disable_cuda_graph,
|
||||
spec_algorithm=self.spec_algorithm,
|
||||
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
||||
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
|
||||
enable_deepep_moe=self.server_args.enable_deepep_moe,
|
||||
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1661,6 +1667,9 @@ class Scheduler(
|
||||
disable_cuda_graph: bool,
|
||||
spec_algorithm,
|
||||
speculative_num_draft_tokens,
|
||||
enable_two_batch_overlap: bool,
|
||||
enable_deepep_moe: bool,
|
||||
deepep_mode: DeepEPMode,
|
||||
):
|
||||
# Check if other DP workers have running batches
|
||||
if local_batch is None:
|
||||
@@ -1696,17 +1705,26 @@ class Scheduler(
|
||||
is_extend_in_batch = (
|
||||
local_batch.forward_mode.is_extend() if local_batch else False
|
||||
)
|
||||
|
||||
tbo_preparer = TboDPAttentionPreparer()
|
||||
|
||||
local_info = torch.tensor(
|
||||
[
|
||||
num_tokens,
|
||||
can_cuda_graph,
|
||||
num_tokens_for_logprob,
|
||||
is_extend_in_batch,
|
||||
*tbo_preparer.prepare_all_gather(
|
||||
local_batch,
|
||||
deepep_mode,
|
||||
enable_deepep_moe,
|
||||
enable_two_batch_overlap,
|
||||
),
|
||||
],
|
||||
dtype=torch.int64,
|
||||
)
|
||||
global_info = torch.empty(
|
||||
(dp_size, attn_tp_size, 4),
|
||||
(dp_size, attn_tp_size, 6),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
@@ -1719,6 +1737,10 @@ class Scheduler(
|
||||
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
|
||||
is_extend_in_batch = global_info[:, 0, 3].tolist()
|
||||
|
||||
tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
|
||||
global_info[:, :, 4:6]
|
||||
)
|
||||
|
||||
if local_batch is None and max(global_num_tokens) > 0:
|
||||
local_batch = get_idle_batch()
|
||||
|
||||
@@ -1732,6 +1754,8 @@ class Scheduler(
|
||||
local_batch.global_num_tokens_for_logprob = (
|
||||
global_num_tokens_for_logprob
|
||||
)
|
||||
local_batch.tbo_split_seq_index = tbo_split_seq_index
|
||||
local_batch.global_forward_mode = global_forward_mode
|
||||
|
||||
# Check forward mode for cuda graph
|
||||
if not disable_cuda_graph:
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from sglang.srt import two_batch_overlap
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
||||
@@ -38,6 +39,10 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||
from sglang.srt.two_batch_overlap import (
|
||||
TboCudaGraphRunnerUtils,
|
||||
TboForwardBatchPreparer,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
get_available_gpu_memory,
|
||||
get_device_memory_capacity,
|
||||
@@ -152,6 +157,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
model_runner.req_to_token_pool.size
|
||||
]
|
||||
|
||||
if server_args.enable_two_batch_overlap:
|
||||
capture_bs = [bs for bs in capture_bs if bs >= 2]
|
||||
|
||||
if server_args.cuda_graph_max_bs:
|
||||
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
|
||||
if max(capture_bs) < server_args.cuda_graph_max_bs:
|
||||
@@ -349,7 +357,14 @@ class CudaGraphRunner:
|
||||
if self.is_encoder_decoder
|
||||
else True
|
||||
)
|
||||
return is_bs_supported and is_encoder_lens_supported
|
||||
|
||||
is_tbo_supported = (
|
||||
forward_batch.can_run_tbo
|
||||
if self.model_runner.server_args.enable_two_batch_overlap
|
||||
else True
|
||||
)
|
||||
|
||||
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
|
||||
|
||||
def capture(self):
|
||||
with graph_capture() as graph_capture_context:
|
||||
@@ -466,7 +481,12 @@ class CudaGraphRunner:
|
||||
capture_hidden_mode=self.capture_hidden_mode,
|
||||
lora_paths=lora_paths,
|
||||
num_token_non_padded=self.num_token_non_padded,
|
||||
tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
|
||||
self, num_tokens
|
||||
),
|
||||
global_forward_mode=self.capture_forward_mode,
|
||||
)
|
||||
TboForwardBatchPreparer.prepare(forward_batch)
|
||||
|
||||
if lora_paths is not None:
|
||||
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
|
||||
|
||||
@@ -29,9 +29,10 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -239,6 +240,7 @@ class ForwardBatch:
|
||||
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
|
||||
gathered_buffer: Optional[torch.Tensor] = None
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
global_forward_mode: Optional[ForwardMode] = None
|
||||
|
||||
# Speculative decoding
|
||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
||||
@@ -252,12 +254,18 @@ class ForwardBatch:
|
||||
# For Qwen2-VL
|
||||
mrope_positions: torch.Tensor = None
|
||||
|
||||
tbo_split_seq_index: Optional[int] = None
|
||||
tbo_parent_token_range: Optional[Tuple[int, int]] = None
|
||||
tbo_children: Optional[List["ForwardBatch"]] = None
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
batch: ModelWorkerBatch,
|
||||
model_runner: ModelRunner,
|
||||
):
|
||||
from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
|
||||
|
||||
device = model_runner.device
|
||||
extend_input_logprob_token_ids_gpu = None
|
||||
if batch.extend_input_logprob_token_ids is not None:
|
||||
@@ -281,6 +289,7 @@ class ForwardBatch:
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
token_ids_logprobs=batch.token_ids_logprobs,
|
||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||
global_forward_mode=batch.global_forward_mode,
|
||||
lora_paths=batch.lora_paths,
|
||||
sampling_info=batch.sampling_info,
|
||||
req_to_token_pool=model_runner.req_to_token_pool,
|
||||
@@ -294,6 +303,7 @@ class ForwardBatch:
|
||||
num_token_non_padded=torch.tensor(
|
||||
len(batch.input_ids), dtype=torch.int32
|
||||
).to(device, non_blocking=True),
|
||||
tbo_split_seq_index=batch.tbo_split_seq_index,
|
||||
)
|
||||
|
||||
# For DP attention
|
||||
@@ -316,6 +326,7 @@ class ForwardBatch:
|
||||
)
|
||||
if ret.forward_mode.is_idle():
|
||||
ret.positions = torch.empty((0,), device=device)
|
||||
TboForwardBatchPreparer.prepare(ret)
|
||||
return ret
|
||||
|
||||
# Override the positions with spec_info
|
||||
@@ -364,6 +375,8 @@ class ForwardBatch:
|
||||
if model_runner.server_args.lora_paths is not None:
|
||||
model_runner.lora_manager.prepare_lora_batch(ret)
|
||||
|
||||
TboForwardBatchPreparer.prepare(ret)
|
||||
|
||||
return ret
|
||||
|
||||
def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
|
||||
@@ -588,6 +601,10 @@ class ForwardBatch:
|
||||
# Precompute the kv indices for each chunk
|
||||
self.prepare_chunked_kv_indices(device)
|
||||
|
||||
@property
|
||||
def can_run_tbo(self):
|
||||
return self.tbo_split_seq_index is not None
|
||||
|
||||
|
||||
class PPProxyTensors:
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
|
||||
|
||||
@@ -37,6 +37,7 @@ from sglang.srt.distributed import (
|
||||
set_custom_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
|
||||
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_group,
|
||||
get_attention_tp_size,
|
||||
@@ -198,6 +199,7 @@ class ModelRunner:
|
||||
"disable_radix_cache": server_args.disable_radix_cache,
|
||||
"enable_nan_detection": server_args.enable_nan_detection,
|
||||
"enable_dp_attention": server_args.enable_dp_attention,
|
||||
"enable_two_batch_overlap": server_args.enable_two_batch_overlap,
|
||||
"enable_dp_lm_head": server_args.enable_dp_lm_head,
|
||||
"enable_ep_moe": server_args.enable_ep_moe,
|
||||
"enable_deepep_moe": server_args.enable_deepep_moe,
|
||||
@@ -994,6 +996,13 @@ class ModelRunner:
|
||||
|
||||
def init_attention_backend(self):
|
||||
"""Init attention kernel backend."""
|
||||
if self.server_args.enable_two_batch_overlap:
|
||||
self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
|
||||
else:
|
||||
self.attn_backend = self._get_attention_backend()
|
||||
|
||||
# TODO unify with 6338
|
||||
def _get_attention_backend(self):
|
||||
if self.server_args.attention_backend == "flashinfer":
|
||||
if not self.use_mla_backend:
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
@@ -1003,17 +1012,17 @@ class ModelRunner:
|
||||
# Init streams
|
||||
if self.server_args.speculative_algorithm == "EAGLE":
|
||||
self.plan_stream_for_flashinfer = torch.cuda.Stream()
|
||||
self.attn_backend = FlashInferAttnBackend(self)
|
||||
return FlashInferAttnBackend(self)
|
||||
else:
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAAttnBackend,
|
||||
)
|
||||
|
||||
self.attn_backend = FlashInferMLAAttnBackend(self)
|
||||
return FlashInferMLAAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "aiter":
|
||||
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
|
||||
|
||||
self.attn_backend = AiterAttnBackend(self)
|
||||
return AiterAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "triton":
|
||||
assert self.sliding_window_size is None, (
|
||||
"Window attention is not supported in the triton attention backend. "
|
||||
@@ -1028,21 +1037,21 @@ class ModelRunner:
|
||||
DoubleSparseAttnBackend,
|
||||
)
|
||||
|
||||
self.attn_backend = DoubleSparseAttnBackend(self)
|
||||
return DoubleSparseAttnBackend(self)
|
||||
else:
|
||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||
|
||||
self.attn_backend = TritonAttnBackend(self)
|
||||
return TritonAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "torch_native":
|
||||
from sglang.srt.layers.attention.torch_native_backend import (
|
||||
TorchNativeAttnBackend,
|
||||
)
|
||||
|
||||
self.attn_backend = TorchNativeAttnBackend(self)
|
||||
return TorchNativeAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "flashmla":
|
||||
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
|
||||
|
||||
self.attn_backend = FlashMLABackend(self)
|
||||
return FlashMLABackend(self)
|
||||
elif self.server_args.attention_backend == "fa3":
|
||||
assert (
|
||||
torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
|
||||
@@ -1054,13 +1063,13 @@ class ModelRunner:
|
||||
FlashAttentionBackend,
|
||||
)
|
||||
|
||||
self.attn_backend = FlashAttentionBackend(self)
|
||||
return FlashAttentionBackend(self)
|
||||
elif self.server_args.attention_backend == "cutlass_mla":
|
||||
from sglang.srt.layers.attention.cutlass_mla_backend import (
|
||||
CutlassMLABackend,
|
||||
)
|
||||
|
||||
self.attn_backend = CutlassMLABackend(self)
|
||||
return CutlassMLABackend(self)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: {self.server_args.attention_backend}"
|
||||
|
||||
@@ -83,8 +83,10 @@ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchI
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.operations import execute_operations
|
||||
from sglang.srt.operations_strategy import compute_layer_operations
|
||||
from sglang.srt.two_batch_overlap import (
|
||||
MaybeTboDeepEPDispatcher,
|
||||
model_forward_maybe_tbo,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
BumpAllocator,
|
||||
DeepEPMode,
|
||||
@@ -226,6 +228,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_shared_experts = config.n_shared_experts
|
||||
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||
self.config = config
|
||||
self.layer_id = layer_id
|
||||
|
||||
if self.tp_size > config.n_routed_experts:
|
||||
@@ -300,7 +303,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
else None
|
||||
)
|
||||
|
||||
self.deepep_dispatcher = DeepEPDispatcher(
|
||||
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
||||
group=parallel_state.get_tp_group().device_group,
|
||||
router_topk=self.top_k,
|
||||
permute_fusion=True,
|
||||
@@ -309,13 +312,11 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_size=config.hidden_size,
|
||||
params_dtype=config.torch_dtype,
|
||||
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
||||
async_finish=True, # TODO
|
||||
async_finish=True,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _enable_deepep_moe(self):
|
||||
return global_server_args_dict["enable_deepep_moe"]
|
||||
self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
|
||||
|
||||
def get_moe_weights(self):
|
||||
return [
|
||||
@@ -423,7 +424,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
return None
|
||||
|
||||
def op_gate(self, state):
|
||||
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
|
||||
if is_non_idle_and_non_empty(
|
||||
state.forward_batch.forward_mode, state.hidden_states_mlp_input
|
||||
):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
@@ -432,115 +433,105 @@ class DeepseekV2MoE(nn.Module):
|
||||
state.router_logits = None
|
||||
|
||||
def op_shared_experts(self, state):
|
||||
if (self.n_share_experts_fusion == 0) and (
|
||||
(not self._enable_deepep_moe)
|
||||
or is_non_idle_and_non_empty(
|
||||
state.forward_batch.forward_mode, state.hidden_states_mlp_input
|
||||
)
|
||||
hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
|
||||
if (self.n_share_experts_fusion == 0) and is_non_idle_and_non_empty(
|
||||
state.forward_batch.forward_mode, hidden_states_mlp_input
|
||||
):
|
||||
state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
|
||||
state.shared_output = self.shared_experts(hidden_states_mlp_input)
|
||||
else:
|
||||
state.shared_output = None
|
||||
|
||||
def op_select_experts(self, state):
|
||||
router_logits = state.router_logits
|
||||
router_logits = state.pop("router_logits")
|
||||
hidden_states = state.hidden_states_mlp_input
|
||||
|
||||
if self._enable_deepep_moe:
|
||||
if router_logits is not None:
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
state.topk_idx_local = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
)
|
||||
state.topk_weights_local = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
if router_logits is not None:
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
state.topk_idx_local = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
)
|
||||
state.topk_weights_local = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
|
||||
def op_dispatch_a(self, state):
|
||||
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||
if self.ep_size > 1:
|
||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||
self.deepep_dispatcher.dispatch_a(
|
||||
hidden_states=state.pop("hidden_states_mlp_input"),
|
||||
hidden_states=state.hidden_states_mlp_input,
|
||||
topk_idx=state.pop("topk_idx_local"),
|
||||
topk_weights=state.pop("topk_weights_local"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_dispatch_b(self, state):
|
||||
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||
(
|
||||
state.hidden_states_experts_input,
|
||||
state.topk_idx_dispatched,
|
||||
state.topk_weights_dispatched,
|
||||
state.reorder_topk_ids,
|
||||
state.num_recv_tokens_per_expert,
|
||||
state.seg_indptr,
|
||||
state.masked_m,
|
||||
state.expected_m,
|
||||
) = self.deepep_dispatcher.dispatch_b()
|
||||
if self.ep_size > 1:
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
(
|
||||
state.hidden_states_experts_input,
|
||||
state.topk_idx_dispatched,
|
||||
state.topk_weights_dispatched,
|
||||
state.reorder_topk_ids,
|
||||
state.num_recv_tokens_per_expert,
|
||||
state.seg_indptr,
|
||||
state.masked_m,
|
||||
state.expected_m,
|
||||
) = self.deepep_dispatcher.dispatch_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_experts(self, state):
|
||||
if self._enable_deepep_moe:
|
||||
state.pop("router_logits")
|
||||
state.hidden_states_experts_output = self.experts(
|
||||
hidden_states=state.pop("hidden_states_experts_input"),
|
||||
topk_idx=state.topk_idx_dispatched,
|
||||
topk_weights=state.topk_weights_dispatched,
|
||||
reorder_topk_ids=state.pop("reorder_topk_ids"),
|
||||
seg_indptr=state.pop("seg_indptr"),
|
||||
masked_m=state.pop("masked_m"),
|
||||
expected_m=state.pop("expected_m"),
|
||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
)
|
||||
else:
|
||||
state.hidden_states_experts_output = self.experts(
|
||||
hidden_states=state.pop("hidden_states_mlp_input"),
|
||||
router_logits=state.pop("router_logits"),
|
||||
)
|
||||
state.hidden_states_experts_output = self.experts(
|
||||
hidden_states=state.pop("hidden_states_experts_input"),
|
||||
topk_idx=state.topk_idx_dispatched,
|
||||
topk_weights=state.topk_weights_dispatched,
|
||||
reorder_topk_ids=state.pop("reorder_topk_ids"),
|
||||
seg_indptr=state.pop("seg_indptr"),
|
||||
masked_m=state.pop("masked_m"),
|
||||
expected_m=state.pop("expected_m"),
|
||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
)
|
||||
|
||||
def op_combine_a(self, state):
|
||||
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||
if self.ep_size > 1:
|
||||
self.deepep_dispatcher.combine_a(
|
||||
state.pop("hidden_states_experts_output"),
|
||||
hidden_states=state.pop("hidden_states_experts_output"),
|
||||
topk_idx=state.pop("topk_idx_dispatched"),
|
||||
topk_weights=state.pop("topk_weights_dispatched"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_combine_b(self, state):
|
||||
if self._enable_deepep_moe and (self.ep_size > 1):
|
||||
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
|
||||
if self.ep_size > 1:
|
||||
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_output(self, state):
|
||||
final_hidden_states = (
|
||||
state.pop("hidden_states_after_combine")
|
||||
if self._enable_deepep_moe
|
||||
else state.pop("hidden_states_experts_output")
|
||||
)
|
||||
|
||||
final_hidden_states = state.pop("hidden_states_after_combine")
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
|
||||
if (s := state.pop("shared_output")) is not None:
|
||||
final_hidden_states = final_hidden_states + s
|
||||
|
||||
if (not self._enable_deepep_moe) and (self.tp_size > 1):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
state.hidden_states_mlp_output = final_hidden_states
|
||||
|
||||
|
||||
@@ -1482,6 +1473,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
zero_allocator: BumpAllocator,
|
||||
tbo_subbatch_index: Optional[int] = None,
|
||||
):
|
||||
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
|
||||
self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
|
||||
@@ -1491,6 +1483,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch=forward_batch,
|
||||
positions=positions,
|
||||
zero_allocator=zero_allocator,
|
||||
tbo_subbatch_index=tbo_subbatch_index,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1523,8 +1516,24 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
state.forward_batch,
|
||||
)
|
||||
|
||||
state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
|
||||
return hidden_states, residual
|
||||
output = dict(
|
||||
positions=state.positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
forward_batch=state.forward_batch,
|
||||
zero_allocator=state.zero_allocator,
|
||||
tbo_subbatch_index=state.tbo_subbatch_index,
|
||||
)
|
||||
|
||||
state.clear(
|
||||
expect_keys={
|
||||
"positions",
|
||||
"forward_batch",
|
||||
"zero_allocator",
|
||||
"tbo_subbatch_index",
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class DeepseekV2Model(nn.Module):
|
||||
@@ -1539,6 +1548,7 @@ class DeepseekV2Model(nn.Module):
|
||||
super().__init__()
|
||||
self.padding_id = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.first_k_dense_replace = config.first_k_dense_replace
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
@@ -1572,13 +1582,12 @@ class DeepseekV2Model(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
total_num_layers = len(self.layers)
|
||||
device = input_embeds.device if input_embeds is not None else input_ids.device
|
||||
zero_allocator = BumpAllocator(
|
||||
# TODO for two-batch-overlap, we need a larger buffer size
|
||||
buffer_size=len(self.layers) * 2,
|
||||
buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
|
||||
dtype=torch.float32,
|
||||
device=(
|
||||
input_embeds.device if input_embeds is not None else input_ids.device
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
|
||||
if input_embeds is None:
|
||||
@@ -1587,12 +1596,30 @@ class DeepseekV2Model(nn.Module):
|
||||
hidden_states = input_embeds
|
||||
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
|
||||
normal_num_layers = (
|
||||
self.first_k_dense_replace
|
||||
if forward_batch.can_run_tbo
|
||||
else total_num_layers
|
||||
)
|
||||
for i in range(normal_num_layers):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||
)
|
||||
|
||||
if normal_num_layers != total_num_layers:
|
||||
hidden_states, residual = model_forward_maybe_tbo(
|
||||
layers=self.layers[normal_num_layers:],
|
||||
enable_tbo=True,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
zero_allocator=zero_allocator,
|
||||
)
|
||||
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
if residual is None:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@@ -1674,7 +1701,6 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
|
||||
return self.logits_processor(
|
||||
|
||||
@@ -12,7 +12,7 @@ if _ENABLE_PROFILE:
|
||||
|
||||
|
||||
def execute_operations(inputs, operations):
|
||||
stages = _convert_operations_to_stages(decorate_operations(operations))
|
||||
stages = _convert_operations_to_stages(operations)
|
||||
executor = _StageExecutor("primary", stages, inputs=inputs)
|
||||
for _ in range(executor.num_stages):
|
||||
executor.next()
|
||||
@@ -20,6 +20,37 @@ def execute_operations(inputs, operations):
|
||||
return executor.output
|
||||
|
||||
|
||||
def execute_overlapped_operations(
|
||||
inputs_arr: Sequence,
|
||||
operations_arr: Sequence,
|
||||
delta_stages: Sequence[int],
|
||||
) -> Sequence:
|
||||
# Make it explicit for clarity; if we need multi-batch overlap, this can be generalized
|
||||
inputs_a, inputs_b = inputs_arr
|
||||
operations_a, operations_b = operations_arr
|
||||
delta_stage_a, delta_stage_b = delta_stages
|
||||
assert delta_stage_a == 0
|
||||
delta_stage = delta_stage_b
|
||||
|
||||
stages_a = _convert_operations_to_stages(operations_a)
|
||||
stages_b = _convert_operations_to_stages(operations_b)
|
||||
executor_a = _StageExecutor("a", stages_a, inputs=inputs_a)
|
||||
executor_b = _StageExecutor("b", stages_b, inputs=inputs_b)
|
||||
|
||||
for _ in range(delta_stage):
|
||||
executor_a.next()
|
||||
|
||||
for _ in range(executor_a.num_stages - delta_stage):
|
||||
executor_a.next()
|
||||
executor_b.next()
|
||||
|
||||
for _ in range(delta_stage):
|
||||
executor_b.next()
|
||||
|
||||
assert executor_a.done and executor_b.done
|
||||
return [executor_a.output, executor_b.output]
|
||||
|
||||
|
||||
class YieldOperation:
|
||||
pass
|
||||
|
||||
@@ -109,6 +140,9 @@ class _StateDict:
|
||||
for k, v in values.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def get(self, item):
|
||||
return self._data.get(item)
|
||||
|
||||
def clear(self, expect_keys: Sequence[str]):
|
||||
if set(self._data.keys()) != set(expect_keys):
|
||||
raise Exception(
|
||||
@@ -119,6 +153,7 @@ class _StateDict:
|
||||
|
||||
|
||||
def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
|
||||
operations = _decorate_operations(operations)
|
||||
operation_chunks = list(
|
||||
_chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))
|
||||
)
|
||||
@@ -140,7 +175,7 @@ def _chunk_by_separator(
|
||||
yield pending_items
|
||||
|
||||
|
||||
def decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
|
||||
def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
|
||||
return [_decorate_operation(op, debug_name_prefix) for op in operations]
|
||||
|
||||
|
||||
|
||||
@@ -1,33 +1,116 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt import operations
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.operations import Operation
|
||||
|
||||
def compute_layer_operations(
|
||||
|
||||
@dataclass
|
||||
class OperationsStrategy:
|
||||
operations: List[Operation]
|
||||
deep_gemm_num_sms: Optional[int] = None
|
||||
tbo_delta_stages: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def concat(cls, items: List["OperationsStrategy"]) -> "OperationsStrategy":
|
||||
return OperationsStrategy(
|
||||
operations=[x for item in items for x in item.operations],
|
||||
deep_gemm_num_sms=_assert_all_same(
|
||||
[item.deep_gemm_num_sms for item in items]
|
||||
),
|
||||
tbo_delta_stages=_assert_all_same(
|
||||
[item.tbo_delta_stages for item in items]
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_new_tbo(
|
||||
layers: torch.nn.ModuleList,
|
||||
forward_mode: ForwardMode,
|
||||
) -> "OperationsStrategy":
|
||||
return OperationsStrategy.concat(
|
||||
[
|
||||
_compute_layer_operations_strategy_tbo(layer, forward_mode)
|
||||
for layer in layers
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _assert_all_same(items: List):
|
||||
assert all(item == items[0] for item in items)
|
||||
return items[0]
|
||||
|
||||
|
||||
# TODO can refactor to make it more fancy if we have more complex strategies
|
||||
def _compute_layer_operations_strategy_tbo(
|
||||
layer: torch.nn.Module,
|
||||
):
|
||||
if not layer.is_layer_sparse:
|
||||
return [
|
||||
forward_mode: ForwardMode,
|
||||
) -> OperationsStrategy:
|
||||
assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
return _compute_moe_deepseek_blog_prefill(layer)
|
||||
elif forward_mode == ForwardMode.DECODE:
|
||||
return _compute_moe_deepseek_blog_decode(layer)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
||||
|
||||
|
||||
def _compute_moe_deepseek_blog_prefill(layer):
|
||||
device_properties = torch.cuda.get_device_properties(device="cuda")
|
||||
total_num_sms = device_properties.multi_processor_count
|
||||
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
|
||||
|
||||
return OperationsStrategy(
|
||||
deep_gemm_num_sms=deep_gemm_num_sms,
|
||||
tbo_delta_stages=0,
|
||||
operations=[
|
||||
layer.op_comm_prepare_attn,
|
||||
layer.self_attn.op_prepare,
|
||||
layer.self_attn.op_core,
|
||||
layer.op_comm_prepare_mlp,
|
||||
layer.op_mlp,
|
||||
layer.mlp.op_gate,
|
||||
layer.mlp.op_select_experts,
|
||||
layer.mlp.op_dispatch_a,
|
||||
operations.YieldOperation(),
|
||||
layer.mlp.op_dispatch_b,
|
||||
layer.mlp.op_experts,
|
||||
layer.mlp.op_combine_a,
|
||||
operations.YieldOperation(),
|
||||
layer.mlp.op_shared_experts,
|
||||
layer.mlp.op_combine_b,
|
||||
layer.mlp.op_output,
|
||||
layer.op_comm_postprocess_layer,
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
# Will add TBO operation orders here
|
||||
return [
|
||||
layer.op_comm_prepare_attn,
|
||||
layer.self_attn.op_prepare,
|
||||
layer.self_attn.op_core,
|
||||
layer.op_comm_prepare_mlp,
|
||||
layer.mlp.op_gate,
|
||||
layer.mlp.op_shared_experts,
|
||||
layer.mlp.op_select_experts,
|
||||
layer.mlp.op_dispatch_a,
|
||||
layer.mlp.op_dispatch_b,
|
||||
layer.mlp.op_experts,
|
||||
layer.mlp.op_combine_a,
|
||||
layer.mlp.op_combine_b,
|
||||
layer.mlp.op_output,
|
||||
layer.op_comm_postprocess_layer,
|
||||
]
|
||||
|
||||
def _compute_moe_deepseek_blog_decode(layer):
|
||||
return OperationsStrategy(
|
||||
deep_gemm_num_sms=None,
|
||||
tbo_delta_stages=2,
|
||||
operations=[
|
||||
layer.op_comm_prepare_attn,
|
||||
layer.self_attn.op_prepare,
|
||||
operations.YieldOperation(),
|
||||
layer.self_attn.op_core,
|
||||
layer.op_comm_prepare_mlp,
|
||||
layer.mlp.op_gate,
|
||||
layer.mlp.op_select_experts,
|
||||
operations.YieldOperation(),
|
||||
layer.mlp.op_dispatch_a,
|
||||
layer.mlp.op_shared_experts,
|
||||
operations.YieldOperation(),
|
||||
layer.mlp.op_dispatch_b,
|
||||
layer.mlp.op_experts,
|
||||
layer.mlp.op_combine_a,
|
||||
operations.YieldOperation(),
|
||||
layer.mlp.op_combine_b,
|
||||
layer.mlp.op_output,
|
||||
layer.op_comm_postprocess_layer,
|
||||
operations.YieldOperation(),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -167,6 +167,7 @@ class ServerArgs:
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_dp_attention: bool = False
|
||||
enable_dp_lm_head: bool = False
|
||||
enable_two_batch_overlap: bool = False
|
||||
enable_ep_moe: bool = False
|
||||
enable_deepep_moe: bool = False
|
||||
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
||||
@@ -1144,6 +1145,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-two-batch-overlap",
|
||||
action="store_true",
|
||||
help="Enabling two micro batches to overlap.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-torch-compile",
|
||||
action="store_true",
|
||||
|
||||
462
python/sglang/srt/two_batch_overlap.py
Normal file
462
python/sglang/srt/two_batch_overlap.py
Normal file
@@ -0,0 +1,462 @@
|
||||
import dataclasses
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||
from sglang.srt.operations_strategy import OperationsStrategy
|
||||
from sglang.srt.utils import BumpAllocator, DeepEPMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
|
||||
|
||||
# -------------------------------- Compute Basic Info ---------------------------------------
|
||||
|
||||
|
||||
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
|
||||
def compute_split_seq_index(
|
||||
forward_mode: "ForwardMode",
|
||||
num_tokens: int,
|
||||
extend_lens: Optional[Sequence[int]],
|
||||
) -> Optional[int]:
|
||||
if forward_mode.is_extend():
|
||||
assert extend_lens is not None
|
||||
return _split_array_by_half_sum(extend_lens)
|
||||
elif forward_mode.is_decode():
|
||||
return num_tokens // 2
|
||||
elif forward_mode.is_idle():
|
||||
assert num_tokens == 0
|
||||
return 0
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
||||
overall_sum = sum(arr)
|
||||
accumulator, split_index = 0, 0
|
||||
for value in arr[:-1]:
|
||||
accumulator += value
|
||||
split_index += 1
|
||||
if accumulator >= overall_sum // 2:
|
||||
break
|
||||
return split_index
|
||||
|
||||
|
||||
def compute_split_token_index(
|
||||
split_seq_index: int,
|
||||
forward_mode: "ForwardMode",
|
||||
extend_seq_lens: Optional[Sequence[int]],
|
||||
) -> int:
|
||||
if forward_mode.is_extend():
|
||||
assert extend_seq_lens is not None
|
||||
return sum(extend_seq_lens[:split_seq_index])
|
||||
elif forward_mode.is_decode():
|
||||
return split_seq_index
|
||||
elif forward_mode.is_idle():
|
||||
assert split_seq_index == 0
|
||||
return 0
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# -------------------------------- Preparation ---------------------------------------
|
||||
|
||||
|
||||
class TboCudaGraphRunnerUtils:
|
||||
@staticmethod
|
||||
def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int):
|
||||
if that.model_runner.server_args.enable_two_batch_overlap:
|
||||
tbo_split_seq_index = compute_split_seq_index(
|
||||
forward_mode=that.capture_forward_mode,
|
||||
num_tokens=num_tokens,
|
||||
extend_lens=None,
|
||||
)
|
||||
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
|
||||
assert (
|
||||
tbo_split_seq_index is not None
|
||||
), f"{that.capture_forward_mode=} {num_tokens=}"
|
||||
else:
|
||||
tbo_split_seq_index = None
|
||||
return tbo_split_seq_index
|
||||
|
||||
|
||||
class TboDPAttentionPreparer:
|
||||
def prepare_all_gather(
|
||||
self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
|
||||
):
|
||||
self.enable_two_batch_overlap = enable_two_batch_overlap
|
||||
|
||||
if local_batch is not None:
|
||||
self.local_tbo_split_seq_index = compute_split_seq_index(
|
||||
forward_mode=local_batch.forward_mode,
|
||||
num_tokens=local_batch.input_ids.shape[0],
|
||||
extend_lens=local_batch.extend_lens,
|
||||
)
|
||||
resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
|
||||
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
|
||||
local_batch.forward_mode.is_extend()
|
||||
and enable_deepep_moe
|
||||
and (resolved_deepep_mode == DeepEPMode.low_latency)
|
||||
)
|
||||
else:
|
||||
self.local_tbo_split_seq_index = 0
|
||||
local_can_run_tbo = True
|
||||
|
||||
local_forward_mode = self._compute_local_forward_mode(local_batch)
|
||||
|
||||
return local_can_run_tbo, local_forward_mode
|
||||
|
||||
def compute_output(self, partial_global_info):
|
||||
local_can_run_tbo_aggregated = min(partial_global_info[:, 0, 0].tolist())
|
||||
forward_modes = partial_global_info[:, 0, 1].tolist()
|
||||
|
||||
global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(
|
||||
forward_modes
|
||||
)
|
||||
|
||||
can_run_tbo = (
|
||||
self.enable_two_batch_overlap
|
||||
and local_can_run_tbo_aggregated
|
||||
and forward_mode_agree
|
||||
)
|
||||
|
||||
tbo_split_seq_index = self.local_tbo_split_seq_index if can_run_tbo else None
|
||||
global_forward_mode = global_forward_mode if can_run_tbo else None
|
||||
return tbo_split_seq_index, global_forward_mode
|
||||
|
||||
@staticmethod
|
||||
def _compute_local_forward_mode(local_batch):
|
||||
return (
|
||||
local_batch.forward_mode if local_batch is not None else ForwardMode.IDLE
|
||||
).value
|
||||
|
||||
@staticmethod
|
||||
def _compute_global_forward_mode(forward_modes):
|
||||
converted_forward_modes = [
|
||||
ForwardMode.DECODE.value if x == ForwardMode.IDLE.value else x
|
||||
for x in forward_modes
|
||||
]
|
||||
forward_mode_agree = TboDPAttentionPreparer._is_all_same(
|
||||
converted_forward_modes
|
||||
)
|
||||
global_forward_mode = (
|
||||
ForwardMode(converted_forward_modes[0]) if forward_mode_agree else None
|
||||
)
|
||||
return global_forward_mode, forward_mode_agree
|
||||
|
||||
@staticmethod
|
||||
def _is_all_same(x):
|
||||
return all(value == x[0] for value in x)
|
||||
|
||||
|
||||
class TboForwardBatchPreparer:
|
||||
@classmethod
|
||||
def prepare(cls, batch: ForwardBatch):
|
||||
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
|
||||
|
||||
if batch.tbo_split_seq_index is None:
|
||||
return
|
||||
|
||||
tbo_split_token_index = compute_split_token_index(
|
||||
split_seq_index=batch.tbo_split_seq_index,
|
||||
forward_mode=batch.forward_mode,
|
||||
extend_seq_lens=batch.extend_seq_lens_cpu,
|
||||
)
|
||||
|
||||
assert isinstance(batch.attn_backend, TboAttnBackend)
|
||||
attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children
|
||||
|
||||
child_a = cls.filter_batch(
|
||||
batch,
|
||||
start_token_index=0,
|
||||
end_token_index=tbo_split_token_index,
|
||||
start_seq_index=0,
|
||||
end_seq_index=batch.tbo_split_seq_index,
|
||||
output_attn_backend=attn_backend_child_a,
|
||||
)
|
||||
child_b = cls.filter_batch(
|
||||
batch,
|
||||
start_token_index=tbo_split_token_index,
|
||||
end_token_index=batch.input_ids.shape[0],
|
||||
start_seq_index=batch.tbo_split_seq_index,
|
||||
end_seq_index=batch.batch_size,
|
||||
output_attn_backend=attn_backend_child_b,
|
||||
)
|
||||
|
||||
assert batch.tbo_children is None
|
||||
batch.tbo_children = [child_a, child_b]
|
||||
|
||||
@classmethod
|
||||
def filter_batch(
|
||||
cls,
|
||||
batch: ForwardBatch,
|
||||
*,
|
||||
start_token_index: int,
|
||||
end_token_index: int,
|
||||
start_seq_index: int,
|
||||
end_seq_index: int,
|
||||
output_attn_backend: AttentionBackend,
|
||||
):
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
num_tokens = batch.input_ids.shape[0]
|
||||
num_seqs = batch.batch_size
|
||||
|
||||
output_dict = dict()
|
||||
|
||||
for key in [
|
||||
"input_ids",
|
||||
"positions",
|
||||
"out_cache_loc",
|
||||
]:
|
||||
old_value = getattr(batch, key)
|
||||
assert (
|
||||
old_value.shape[0] == num_tokens
|
||||
), f"{key=} {old_value=} {num_tokens=} {batch=}"
|
||||
output_dict[key] = old_value[start_token_index:end_token_index]
|
||||
|
||||
for key in [
|
||||
"req_pool_indices",
|
||||
"seq_lens",
|
||||
"seq_lens_cpu",
|
||||
"extend_seq_lens",
|
||||
"extend_prefix_lens",
|
||||
"extend_start_loc",
|
||||
"extend_prefix_lens_cpu",
|
||||
"extend_seq_lens_cpu",
|
||||
"extend_logprob_start_lens_cpu",
|
||||
"lora_paths",
|
||||
]:
|
||||
old_value = getattr(batch, key)
|
||||
if old_value is None:
|
||||
continue
|
||||
assert (
|
||||
len(old_value) == num_seqs
|
||||
), f"{key=} {old_value=} {num_seqs=} {batch=}"
|
||||
output_dict[key] = old_value[start_seq_index:end_seq_index]
|
||||
|
||||
for key in [
|
||||
"forward_mode",
|
||||
"return_logprob",
|
||||
"req_to_token_pool",
|
||||
"token_to_kv_pool",
|
||||
"can_run_dp_cuda_graph",
|
||||
"global_forward_mode",
|
||||
"spec_info",
|
||||
"spec_algorithm",
|
||||
"capture_hidden_mode",
|
||||
"padded_static_len",
|
||||
"mrope_positions", # only used by qwen2-vl, thus not care
|
||||
]:
|
||||
output_dict[key] = getattr(batch, key)
|
||||
|
||||
assert (
|
||||
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
|
||||
== batch.extend_num_tokens
|
||||
), f"{batch=}"
|
||||
extend_num_tokens = _compute_extend_num_tokens(
|
||||
output_dict["input_ids"], output_dict["forward_mode"]
|
||||
)
|
||||
|
||||
# TODO improve, e.g. unify w/ `init_raw`
|
||||
if global_server_args_dict["moe_dense_tp_size"] == 1:
|
||||
sum_len = end_token_index - start_token_index
|
||||
gathered_buffer = torch.zeros(
|
||||
(sum_len, batch.gathered_buffer.shape[1]),
|
||||
dtype=batch.gathered_buffer.dtype,
|
||||
device=batch.gathered_buffer.device,
|
||||
)
|
||||
else:
|
||||
gathered_buffer = None
|
||||
|
||||
output_dict.update(
|
||||
dict(
|
||||
batch_size=end_seq_index - start_seq_index,
|
||||
seq_lens_sum=(
|
||||
output_dict["seq_lens_cpu"].sum()
|
||||
if "seq_lens_cpu" in output_dict
|
||||
else None
|
||||
),
|
||||
extend_num_tokens=extend_num_tokens,
|
||||
attn_backend=output_attn_backend,
|
||||
tbo_split_seq_index=None,
|
||||
tbo_parent_token_range=(start_token_index, end_token_index),
|
||||
tbo_children=None,
|
||||
global_num_tokens_gpu=None,
|
||||
global_num_tokens_cpu=None,
|
||||
gathered_buffer=gathered_buffer,
|
||||
global_num_tokens_for_logprob_gpu=None,
|
||||
global_num_tokens_for_logprob_cpu=None,
|
||||
sampling_info=None,
|
||||
# For logits and logprobs post processing, thus we do not care
|
||||
temp_scaled_logprobs=False,
|
||||
temperature=None,
|
||||
top_p_normalized_logprobs=False,
|
||||
top_p=None,
|
||||
mm_inputs=None,
|
||||
num_token_non_padded=None,
|
||||
)
|
||||
)
|
||||
|
||||
errors = []
|
||||
for field in dataclasses.fields(ForwardBatch):
|
||||
if getattr(batch, field.name) is not None and field.name not in output_dict:
|
||||
errors.append(
|
||||
f"Field {field.name} has value, but is not yet supported (value={getattr(batch, field.name)} batch={batch})"
|
||||
)
|
||||
if len(errors) > 0:
|
||||
raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors))
|
||||
|
||||
return ForwardBatch(**output_dict)
|
||||
|
||||
|
||||
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
|
||||
if forward_mode.is_extend():
|
||||
return input_ids.shape[0]
|
||||
elif forward_mode.is_decode() or forward_mode.is_idle():
|
||||
return None
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# -------------------------------- Execution ---------------------------------------
|
||||
|
||||
|
||||
def model_forward_maybe_tbo(
|
||||
layers,
|
||||
enable_tbo: bool,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
zero_allocator: BumpAllocator,
|
||||
):
|
||||
inputs = dict(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
residual=residual,
|
||||
zero_allocator=zero_allocator,
|
||||
)
|
||||
operations_strategy = OperationsStrategy.init_new_tbo(
|
||||
layers, forward_batch.global_forward_mode
|
||||
)
|
||||
if enable_tbo:
|
||||
return _model_forward_tbo(inputs, operations_strategy)
|
||||
else:
|
||||
return _model_forward_non_tbo(inputs, operations_strategy)
|
||||
|
||||
|
||||
def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
|
||||
# The attn_tp_size!=1 case is not yet extracted to master
|
||||
assert get_attention_tp_size() == 1
|
||||
|
||||
inputs_arr = _model_forward_tbo_split_inputs(**inputs)
|
||||
del inputs
|
||||
|
||||
with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
|
||||
outputs_arr = execute_overlapped_operations(
|
||||
inputs_arr=inputs_arr,
|
||||
operations_arr=[operations_strategy.operations] * 2,
|
||||
delta_stages=[0, operations_strategy.tbo_delta_stages],
|
||||
)
|
||||
|
||||
return _model_forward_tbo_merge_outputs(*outputs_arr)
|
||||
|
||||
|
||||
def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):
|
||||
outputs = execute_operations(inputs, operations_strategy.operations)
|
||||
return outputs["hidden_states"], outputs["residual"]
|
||||
|
||||
|
||||
def _model_forward_tbo_split_inputs(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
zero_allocator: BumpAllocator,
|
||||
) -> List[Dict]:
|
||||
return [
|
||||
dict(
|
||||
**_model_forward_filter_inputs(
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
positions=positions,
|
||||
output_forward_batch=output_forward_batch,
|
||||
tbo_subbatch_index=tbo_subbatch_index,
|
||||
),
|
||||
zero_allocator=zero_allocator,
|
||||
)
|
||||
for tbo_subbatch_index, output_forward_batch in enumerate(
|
||||
forward_batch.tbo_children
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _model_forward_filter_inputs(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
output_forward_batch: ForwardBatch,
|
||||
tbo_subbatch_index: int,
|
||||
) -> Dict:
|
||||
token_slice = slice(*output_forward_batch.tbo_parent_token_range)
|
||||
return dict(
|
||||
hidden_states=hidden_states[token_slice],
|
||||
residual=None if residual is None else residual[token_slice],
|
||||
positions=positions[token_slice],
|
||||
forward_batch=output_forward_batch,
|
||||
tbo_subbatch_index=tbo_subbatch_index,
|
||||
)
|
||||
|
||||
|
||||
def _model_forward_tbo_merge_outputs(output_a, output_b):
|
||||
def _handle_key(name):
|
||||
value_a = output_a[name]
|
||||
value_b = output_b[name]
|
||||
assert (value_a is None) == (value_b is None)
|
||||
if value_a is None:
|
||||
return None
|
||||
return torch.concat([value_a, value_b], dim=0)
|
||||
|
||||
return _handle_key("hidden_states"), _handle_key("residual")
|
||||
|
||||
|
||||
# -------------------------------- Utilities and wrappers ---------------------------------------
|
||||
|
||||
|
||||
class MaybeTboDeepEPDispatcher:
|
||||
def __init__(self, **kwargs):
|
||||
num_inner_dispatchers = (
|
||||
2 if global_server_args_dict["enable_two_batch_overlap"] else 1
|
||||
)
|
||||
self._inners = [
|
||||
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
|
||||
]
|
||||
|
||||
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
|
||||
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
|
||||
|
||||
def dispatch(self, **kwargs):
|
||||
return self._execute("dispatch", **kwargs)
|
||||
|
||||
def dispatch_a(self, **kwargs):
|
||||
return self._execute("dispatch_a", **kwargs)
|
||||
|
||||
def dispatch_b(self, **kwargs):
|
||||
return self._execute("dispatch_b", **kwargs)
|
||||
|
||||
def combine(self, **kwargs):
|
||||
return self._execute("combine", **kwargs)
|
||||
|
||||
def combine_a(self, **kwargs):
|
||||
return self._execute("combine_a", **kwargs)
|
||||
|
||||
def combine_b(self, **kwargs):
|
||||
return self._execute("combine_b", **kwargs)
|
||||
Reference in New Issue
Block a user