Support overlapping two batches (#4068)

This commit is contained in:
fzyzcjy
2025-05-25 08:39:07 +08:00
committed by GitHub
parent f456037396
commit 0d47788025
13 changed files with 1145 additions and 129 deletions

View File

@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
import tqdm
from sglang.srt import two_batch_overlap
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
@@ -38,6 +39,10 @@ from sglang.srt.model_executor.forward_batch_info import (
PPProxyTensors,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import (
TboCudaGraphRunnerUtils,
TboForwardBatchPreparer,
)
from sglang.srt.utils import (
get_available_gpu_memory,
get_device_memory_capacity,
@@ -152,6 +157,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
model_runner.req_to_token_pool.size
]
if server_args.enable_two_batch_overlap:
capture_bs = [bs for bs in capture_bs if bs >= 2]
if server_args.cuda_graph_max_bs:
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
if max(capture_bs) < server_args.cuda_graph_max_bs:
@@ -349,7 +357,14 @@ class CudaGraphRunner:
if self.is_encoder_decoder
else True
)
return is_bs_supported and is_encoder_lens_supported
is_tbo_supported = (
forward_batch.can_run_tbo
if self.model_runner.server_args.enable_two_batch_overlap
else True
)
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
def capture(self):
with graph_capture() as graph_capture_context:
@@ -466,7 +481,12 @@ class CudaGraphRunner:
capture_hidden_mode=self.capture_hidden_mode,
lora_paths=lora_paths,
num_token_non_padded=self.num_token_non_padded,
tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
self, num_tokens
),
global_forward_mode=self.capture_forward_mode,
)
TboForwardBatchPreparer.prepare(forward_batch)
if lora_paths is not None:
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

View File

@@ -29,9 +29,10 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
from __future__ import annotations
import dataclasses
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
import triton
@@ -239,6 +240,7 @@ class ForwardBatch:
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
gathered_buffer: Optional[torch.Tensor] = None
can_run_dp_cuda_graph: bool = False
global_forward_mode: Optional[ForwardMode] = None
# Speculative decoding
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
@@ -252,12 +254,18 @@ class ForwardBatch:
# For Qwen2-VL
mrope_positions: torch.Tensor = None
tbo_split_seq_index: Optional[int] = None
tbo_parent_token_range: Optional[Tuple[int, int]] = None
tbo_children: Optional[List["ForwardBatch"]] = None
@classmethod
def init_new(
cls,
batch: ModelWorkerBatch,
model_runner: ModelRunner,
):
from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
device = model_runner.device
extend_input_logprob_token_ids_gpu = None
if batch.extend_input_logprob_token_ids is not None:
@@ -281,6 +289,7 @@ class ForwardBatch:
top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
global_forward_mode=batch.global_forward_mode,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
req_to_token_pool=model_runner.req_to_token_pool,
@@ -294,6 +303,7 @@ class ForwardBatch:
num_token_non_padded=torch.tensor(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True),
tbo_split_seq_index=batch.tbo_split_seq_index,
)
# For DP attention
@@ -316,6 +326,7 @@ class ForwardBatch:
)
if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
TboForwardBatchPreparer.prepare(ret)
return ret
# Override the positions with spec_info
@@ -364,6 +375,8 @@ class ForwardBatch:
if model_runner.server_args.lora_paths is not None:
model_runner.lora_manager.prepare_lora_batch(ret)
TboForwardBatchPreparer.prepare(ret)
return ret
def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
@@ -588,6 +601,10 @@ class ForwardBatch:
# Precompute the kv indices for each chunk
self.prepare_chunked_kv_indices(device)
@property
def can_run_tbo(self):
    """Whether this batch is eligible for two-batch overlap (TBO).

    True iff a split point was computed for this batch; a ``None``
    ``tbo_split_seq_index`` means the batch was not (or could not be)
    partitioned into two overlapping micro-batches.
    """
    return self.tbo_split_seq_index is not None
class PPProxyTensors:
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103

View File

@@ -37,6 +37,7 @@ from sglang.srt.distributed import (
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
get_attention_tp_size,
@@ -198,6 +199,7 @@ class ModelRunner:
"disable_radix_cache": server_args.disable_radix_cache,
"enable_nan_detection": server_args.enable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
"enable_two_batch_overlap": server_args.enable_two_batch_overlap,
"enable_dp_lm_head": server_args.enable_dp_lm_head,
"enable_ep_moe": server_args.enable_ep_moe,
"enable_deepep_moe": server_args.enable_deepep_moe,
@@ -994,6 +996,13 @@ class ModelRunner:
def init_attention_backend(self):
    """Initialize the attention kernel backend.

    When two-batch overlap is enabled, the concrete backend construction is
    deferred to ``TboAttnBackend.init_new`` (which receives the factory
    callable); otherwise the backend is built directly.
    """
    tbo_enabled = self.server_args.enable_two_batch_overlap
    self.attn_backend = (
        TboAttnBackend.init_new(self._get_attention_backend)
        if tbo_enabled
        else self._get_attention_backend()
    )
# TODO unify with 6338
def _get_attention_backend(self):
if self.server_args.attention_backend == "flashinfer":
if not self.use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import (
@@ -1003,17 +1012,17 @@ class ModelRunner:
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
self.attn_backend = FlashInferAttnBackend(self)
return FlashInferAttnBackend(self)
else:
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
)
self.attn_backend = FlashInferMLAAttnBackend(self)
return FlashInferMLAAttnBackend(self)
elif self.server_args.attention_backend == "aiter":
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
self.attn_backend = AiterAttnBackend(self)
return AiterAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
@@ -1028,21 +1037,21 @@ class ModelRunner:
DoubleSparseAttnBackend,
)
self.attn_backend = DoubleSparseAttnBackend(self)
return DoubleSparseAttnBackend(self)
else:
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
self.attn_backend = TritonAttnBackend(self)
return TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
from sglang.srt.layers.attention.torch_native_backend import (
TorchNativeAttnBackend,
)
self.attn_backend = TorchNativeAttnBackend(self)
return TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashmla":
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
self.attn_backend = FlashMLABackend(self)
return FlashMLABackend(self)
elif self.server_args.attention_backend == "fa3":
assert (
torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
@@ -1054,13 +1063,13 @@ class ModelRunner:
FlashAttentionBackend,
)
self.attn_backend = FlashAttentionBackend(self)
return FlashAttentionBackend(self)
elif self.server_args.attention_backend == "cutlass_mla":
from sglang.srt.layers.attention.cutlass_mla_backend import (
CutlassMLABackend,
)
self.attn_backend = CutlassMLABackend(self)
return CutlassMLABackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"