Support overlapping two batches (#4068)
This commit is contained in:
@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from sglang.srt import two_batch_overlap
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
||||
@@ -38,6 +39,10 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||
from sglang.srt.two_batch_overlap import (
|
||||
TboCudaGraphRunnerUtils,
|
||||
TboForwardBatchPreparer,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
get_available_gpu_memory,
|
||||
get_device_memory_capacity,
|
||||
@@ -152,6 +157,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
model_runner.req_to_token_pool.size
|
||||
]
|
||||
|
||||
if server_args.enable_two_batch_overlap:
|
||||
capture_bs = [bs for bs in capture_bs if bs >= 2]
|
||||
|
||||
if server_args.cuda_graph_max_bs:
|
||||
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
|
||||
if max(capture_bs) < server_args.cuda_graph_max_bs:
|
||||
@@ -349,7 +357,14 @@ class CudaGraphRunner:
|
||||
if self.is_encoder_decoder
|
||||
else True
|
||||
)
|
||||
return is_bs_supported and is_encoder_lens_supported
|
||||
|
||||
is_tbo_supported = (
|
||||
forward_batch.can_run_tbo
|
||||
if self.model_runner.server_args.enable_two_batch_overlap
|
||||
else True
|
||||
)
|
||||
|
||||
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
|
||||
|
||||
def capture(self):
|
||||
with graph_capture() as graph_capture_context:
|
||||
@@ -466,7 +481,12 @@ class CudaGraphRunner:
|
||||
capture_hidden_mode=self.capture_hidden_mode,
|
||||
lora_paths=lora_paths,
|
||||
num_token_non_padded=self.num_token_non_padded,
|
||||
tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
|
||||
self, num_tokens
|
||||
),
|
||||
global_forward_mode=self.capture_forward_mode,
|
||||
)
|
||||
TboForwardBatchPreparer.prepare(forward_batch)
|
||||
|
||||
if lora_paths is not None:
|
||||
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
|
||||
|
||||
@@ -29,9 +29,10 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -239,6 +240,7 @@ class ForwardBatch:
|
||||
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
|
||||
gathered_buffer: Optional[torch.Tensor] = None
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
global_forward_mode: Optional[ForwardMode] = None
|
||||
|
||||
# Speculative decoding
|
||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
||||
@@ -252,12 +254,18 @@ class ForwardBatch:
|
||||
# For Qwen2-VL
|
||||
mrope_positions: torch.Tensor = None
|
||||
|
||||
tbo_split_seq_index: Optional[int] = None
|
||||
tbo_parent_token_range: Optional[Tuple[int, int]] = None
|
||||
tbo_children: Optional[List["ForwardBatch"]] = None
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
batch: ModelWorkerBatch,
|
||||
model_runner: ModelRunner,
|
||||
):
|
||||
from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
|
||||
|
||||
device = model_runner.device
|
||||
extend_input_logprob_token_ids_gpu = None
|
||||
if batch.extend_input_logprob_token_ids is not None:
|
||||
@@ -281,6 +289,7 @@ class ForwardBatch:
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
token_ids_logprobs=batch.token_ids_logprobs,
|
||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||
global_forward_mode=batch.global_forward_mode,
|
||||
lora_paths=batch.lora_paths,
|
||||
sampling_info=batch.sampling_info,
|
||||
req_to_token_pool=model_runner.req_to_token_pool,
|
||||
@@ -294,6 +303,7 @@ class ForwardBatch:
|
||||
num_token_non_padded=torch.tensor(
|
||||
len(batch.input_ids), dtype=torch.int32
|
||||
).to(device, non_blocking=True),
|
||||
tbo_split_seq_index=batch.tbo_split_seq_index,
|
||||
)
|
||||
|
||||
# For DP attention
|
||||
@@ -316,6 +326,7 @@ class ForwardBatch:
|
||||
)
|
||||
if ret.forward_mode.is_idle():
|
||||
ret.positions = torch.empty((0,), device=device)
|
||||
TboForwardBatchPreparer.prepare(ret)
|
||||
return ret
|
||||
|
||||
# Override the positions with spec_info
|
||||
@@ -364,6 +375,8 @@ class ForwardBatch:
|
||||
if model_runner.server_args.lora_paths is not None:
|
||||
model_runner.lora_manager.prepare_lora_batch(ret)
|
||||
|
||||
TboForwardBatchPreparer.prepare(ret)
|
||||
|
||||
return ret
|
||||
|
||||
def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
|
||||
@@ -588,6 +601,10 @@ class ForwardBatch:
|
||||
# Precompute the kv indices for each chunk
|
||||
self.prepare_chunked_kv_indices(device)
|
||||
|
||||
@property
|
||||
def can_run_tbo(self):
|
||||
return self.tbo_split_seq_index is not None
|
||||
|
||||
|
||||
class PPProxyTensors:
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
|
||||
|
||||
@@ -37,6 +37,7 @@ from sglang.srt.distributed import (
|
||||
set_custom_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
|
||||
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_group,
|
||||
get_attention_tp_size,
|
||||
@@ -198,6 +199,7 @@ class ModelRunner:
|
||||
"disable_radix_cache": server_args.disable_radix_cache,
|
||||
"enable_nan_detection": server_args.enable_nan_detection,
|
||||
"enable_dp_attention": server_args.enable_dp_attention,
|
||||
"enable_two_batch_overlap": server_args.enable_two_batch_overlap,
|
||||
"enable_dp_lm_head": server_args.enable_dp_lm_head,
|
||||
"enable_ep_moe": server_args.enable_ep_moe,
|
||||
"enable_deepep_moe": server_args.enable_deepep_moe,
|
||||
@@ -994,6 +996,13 @@ class ModelRunner:
|
||||
|
||||
def init_attention_backend(self):
|
||||
"""Init attention kernel backend."""
|
||||
if self.server_args.enable_two_batch_overlap:
|
||||
self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
|
||||
else:
|
||||
self.attn_backend = self._get_attention_backend()
|
||||
|
||||
# TODO unify with 6338
|
||||
def _get_attention_backend(self):
|
||||
if self.server_args.attention_backend == "flashinfer":
|
||||
if not self.use_mla_backend:
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
@@ -1003,17 +1012,17 @@ class ModelRunner:
|
||||
# Init streams
|
||||
if self.server_args.speculative_algorithm == "EAGLE":
|
||||
self.plan_stream_for_flashinfer = torch.cuda.Stream()
|
||||
self.attn_backend = FlashInferAttnBackend(self)
|
||||
return FlashInferAttnBackend(self)
|
||||
else:
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAAttnBackend,
|
||||
)
|
||||
|
||||
self.attn_backend = FlashInferMLAAttnBackend(self)
|
||||
return FlashInferMLAAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "aiter":
|
||||
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
|
||||
|
||||
self.attn_backend = AiterAttnBackend(self)
|
||||
return AiterAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "triton":
|
||||
assert self.sliding_window_size is None, (
|
||||
"Window attention is not supported in the triton attention backend. "
|
||||
@@ -1028,21 +1037,21 @@ class ModelRunner:
|
||||
DoubleSparseAttnBackend,
|
||||
)
|
||||
|
||||
self.attn_backend = DoubleSparseAttnBackend(self)
|
||||
return DoubleSparseAttnBackend(self)
|
||||
else:
|
||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||
|
||||
self.attn_backend = TritonAttnBackend(self)
|
||||
return TritonAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "torch_native":
|
||||
from sglang.srt.layers.attention.torch_native_backend import (
|
||||
TorchNativeAttnBackend,
|
||||
)
|
||||
|
||||
self.attn_backend = TorchNativeAttnBackend(self)
|
||||
return TorchNativeAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "flashmla":
|
||||
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
|
||||
|
||||
self.attn_backend = FlashMLABackend(self)
|
||||
return FlashMLABackend(self)
|
||||
elif self.server_args.attention_backend == "fa3":
|
||||
assert (
|
||||
torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
|
||||
@@ -1054,13 +1063,13 @@ class ModelRunner:
|
||||
FlashAttentionBackend,
|
||||
)
|
||||
|
||||
self.attn_backend = FlashAttentionBackend(self)
|
||||
return FlashAttentionBackend(self)
|
||||
elif self.server_args.attention_backend == "cutlass_mla":
|
||||
from sglang.srt.layers.attention.cutlass_mla_backend import (
|
||||
CutlassMLABackend,
|
||||
)
|
||||
|
||||
self.attn_backend = CutlassMLABackend(self)
|
||||
return CutlassMLABackend(self)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: {self.server_args.attention_backend}"
|
||||
|
||||
Reference in New Issue
Block a user