Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
Author: Lianmin Zheng
Date: 2025-03-03 00:12:04 -08:00
Parent: 0194948fd9
Commit: ac2387279e
86 changed files with 4116 additions and 2015 deletions

python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -109,11 +109,15 @@ def set_torch_compile_config():
 def get_batch_sizes_to_capture(model_runner: ModelRunner):
     server_args = model_runner.server_args
     capture_bs = server_args.cuda_graph_bs
     if capture_bs is None:
-        if server_args.disable_cuda_graph_padding:
-            capture_bs = list(range(1, 33)) + [64, 128]
+        if server_args.speculative_algorithm is None:
+            if server_args.disable_cuda_graph_padding:
+                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+            else:
+                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         else:
-            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+            capture_bs = list(range(1, 33))
     if is_hip_:
         capture_bs += [i * 8 for i in range(21, 33)]
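Note: the batch-size selection above, as a standalone sketch (the plain boolean flags are stand-ins for the corresponding ServerArgs fields):

# Sketch of the capture-size policy in this hunk; flag names are stand-ins.
def batch_sizes_to_capture(speculative: bool, disable_padding: bool, is_hip: bool = False):
    if speculative:
        capture_bs = list(range(1, 33))  # spec decoding: smaller, denser range
    elif disable_padding:
        capture_bs = list(range(1, 33)) + [64, 96, 128, 160]  # exact sizes only
    else:
        capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # pad up to these sizes
    if is_hip:
        capture_bs += [i * 8 for i in range(21, 33)]  # ROCm: capture up to bs=256
    return capture_bs

print(batch_sizes_to_capture(speculative=False, disable_padding=False))
# [1, 2, 4, 8, 16, 24, ..., 160]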
@@ -130,6 +134,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
                 )
             )
         )
+
     capture_bs = [
         bs
         for bs in capture_bs
@@ -385,9 +390,6 @@ class CudaGraphRunner:
             run_once()
-            torch.cuda.synchronize()
-            self.model_runner.tp_group.barrier()
         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()
@@ -401,12 +403,11 @@ class CudaGraphRunner:
         global_graph_memory_pool = graph.pool()
         return graph, out
-    def replay(self, forward_batch: ForwardBatch):
-        assert forward_batch.out_cache_loc is not None
+    def recapture_if_needed(self, forward_batch: ForwardBatch):
+        # If the capture_hidden_mode changes, we need to recapture the graph
         hidden_mode_from_spec_info = getattr(
             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
         )
-        # If the capture_hidden_mode changes, we need to recapture the graph
         if (
             forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
             and self.capture_hidden_mode != CaptureHiddenMode.FULL
@@ -420,6 +421,9 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
+    def replay(self, forward_batch: ForwardBatch):
+        self.recapture_if_needed(forward_batch)
         raw_bs = forward_batch.batch_size
         raw_num_token = raw_bs * self.num_tokens_per_bs
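Note: after this refactor, replay() always routes through recapture_if_needed(). A minimal sketch of the flow with stand-in types (not the real CudaGraphRunner):

# Minimal sketch of the recapture-then-replay flow introduced above.
from enum import IntEnum, auto

class CaptureHiddenMode(IntEnum):
    NULL = auto()
    FULL = auto()
    LAST = auto()

class RunnerSketch:
    capture_hidden_mode = CaptureHiddenMode.NULL

    def capture(self):
        print(f"capturing graphs, mode={self.capture_hidden_mode.name}")

    def recapture_if_needed(self, requested: CaptureHiddenMode):
        # A change in the required hidden-state capture mode invalidates the
        # captured graphs, so they must be re-captured before replay.
        if requested != self.capture_hidden_mode:
            self.capture_hidden_mode = requested
            self.capture()

    def replay(self, requested: CaptureHiddenMode):
        self.recapture_if_needed(requested)
        # ... pad the batch, copy inputs into static buffers, graph.replay() ...

RunnerSketch().replay(CaptureHiddenMode.FULL)  # triggers one recapture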

python/sglang/srt/model_executor/forward_batch_info.py

@@ -31,7 +31,7 @@ from __future__ import annotations
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Union
 import torch
 import triton
@@ -46,7 +46,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 class ForwardMode(IntEnum):
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
 class CaptureHiddenMode(IntEnum):
     NULL = auto()
+    # Capture hidden states of all tokens.
     FULL = auto()
+    # Capture the hidden state of the last token.
     LAST = auto()
     def need_capture(self):
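Note: the two commented modes differ only in which rows of the hidden-state tensor are kept. A toy illustration with assumed consumer logic (not from this diff):

# FULL vs. LAST hidden-state capture: 8 tokens from two requests whose last
# positions are 3 and 7.
import torch

hidden = torch.arange(16.0).reshape(8, 2)   # [num_tokens, hidden_size]
last_index = torch.tensor([3, 7])           # last token of each request

full = hidden                               # CaptureHiddenMode.FULL: every token
last = hidden[last_index]                   # CaptureHiddenMode.LAST: one per request
print(full.shape, last.shape)               # torch.Size([8, 2]) torch.Size([2, 2])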
@@ -148,6 +151,7 @@ class ForwardBatch:
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None
     # Position information
     positions: torch.Tensor = None
@@ -160,6 +164,7 @@ class ForwardBatch:
     extend_prefix_lens_cpu: Optional[List[int]] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
+    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
     # For multimodal
     image_inputs: Optional[List[ImageInputs]] = None
@@ -190,10 +195,13 @@ class ForwardBatch:
     can_run_dp_cuda_graph: bool = False
     # Speculative decoding
-    spec_info: SpecInfo = None
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     spec_algorithm: SpeculativeAlgorithm = None
     capture_hidden_mode: CaptureHiddenMode = None
+    # For padding
+    padded_static_len: int = -1  # -1 if not padded
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
@@ -203,8 +211,13 @@ class ForwardBatch:
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
         device = model_runner.device
+        extend_input_logprob_token_ids_gpu = None
+        if batch.extend_input_logprob_token_ids is not None:
+            extend_input_logprob_token_ids_gpu = (
+                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
+            )
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -220,6 +233,7 @@ class ForwardBatch:
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            token_ids_logprobs=batch.token_ids_logprobs,
             global_num_tokens=batch.global_num_tokens,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
@@ -231,6 +245,7 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )
         if ret.global_num_tokens is not None:
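Note: the `.to(device, non_blocking=True)` call above issues an asynchronous host-to-device copy. The general pattern, as a sketch (overlap only happens when the source is in pinned memory):

# Async H2D copy pattern used for extend_input_logprob_token_ids_gpu above.
import torch

if torch.cuda.is_available():
    # non_blocking only overlaps with compute when the source is pinned.
    token_ids = torch.randint(0, 32000, (4096,), dtype=torch.int64).pin_memory()
    token_ids_gpu = token_ids.to("cuda", non_blocking=True)  # enqueue, return at once
    # ... CPU-side batch preparation overlaps with the copy here ...
    torch.cuda.current_stream().synchronize()  # copy guaranteed complete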
@@ -341,6 +356,7 @@ class ForwardBatch:
                     )
                     batch.image_inputs[i].mrope_position_delta = mrope_position_delta
                 mrope_positions_list[i] = mrope_positions
+
         self.mrope_positions = torch.concat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
             axis=1,
@@ -379,7 +395,7 @@ def compute_position_kernel(
     extend_seq_lens,
 ):
     BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     prefix_len = tl.load(extend_prefix_lens + pid)
     seq_len = tl.load(extend_seq_lens + pid)
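Note: the `.to(tl.int64)` cast exists because `tl.program_id` is int32, so address offsets derived from `pid` can wrap past 2**31 for large batch/sequence products. The failure mode in plain NumPy:

# Why the pid is widened to int64: int32 offset arithmetic wraps around.
import numpy as np

pid = np.int32(40000)      # e.g., a large row index
stride = np.int32(65536)   # e.g., max sequence length per row
with np.errstate(over="ignore"):
    wrapped = pid * stride                  # int32: wraps to a negative offset
widened = np.int64(pid) * np.int64(stride)  # int64: correct
print(wrapped, widened)    # -1673527296 vs 2621440000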

python/sglang/srt/model_executor/model_runner.py

@@ -13,9 +13,12 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
+import collections
+import datetime
 import gc
 import json
 import logging
 import os
 import time
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
@@ -58,6 +61,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -73,10 +77,15 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
+    set_cuda_arch,
 )
+from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
+SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
@@ -180,9 +189,13 @@ class ModelRunner:
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
}
)
@@ -199,6 +212,18 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
+        # Handle the case where some ranks do not finish loading the model.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} finished loading the model, but other ranks have not. This is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
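Note: unlike a plain barrier, `dist.monitored_barrier` (gloo-only, hence the CPU group) raises on the ranks that did arrive and names the ranks that never did. A minimal two-process sketch of the pattern:

# Run with: torchrun --nproc_per_node=2 check_barrier.py  (file name hypothetical)
import datetime
import torch.distributed as dist

dist.init_process_group(backend="gloo")
try:
    # wait_all_ranks=True reports every straggler instead of just the first.
    dist.monitored_barrier(
        timeout=datetime.timedelta(seconds=300), wait_all_ranks=True
    )
except RuntimeError as e:
    raise ValueError(f"rank {dist.get_rank()} arrived, others did not: {e}") from None
dist.destroy_process_group()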
@@ -625,6 +650,9 @@ class ModelRunner:
                 4096,
             )
+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
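Note: the CI knob above is a plain environment-variable override, equivalent to:

# Sketch of the SGLANG_CI_SMALL_KV_SIZE override: cap the KV pool in CI jobs.
import os

max_total_num_tokens = 120_000                   # normally derived from free VRAM
small_kv = os.getenv("SGLANG_CI_SMALL_KV_SIZE")  # e.g., export SGLANG_CI_SMALL_KV_SIZE=4096
if small_kv:
    max_total_num_tokens = int(small_kv)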
@@ -655,6 +683,7 @@ class ModelRunner:
             device=self.device,
             enable_memory_saver=self.server_args.enable_memory_saver,
         )
+
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -758,9 +787,13 @@ class ModelRunner:
             return
         tic = time.time()
-        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        logger.info(
+            f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
         self.cuda_graph_runner = CudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        logger.info(
+            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -820,11 +853,10 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
-    def sample(
-        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
-    ) -> torch.Tensor:
+    def _preprocess_logits(
+        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
+    ):
         # Apply logit bias
-        sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
             # in process_batch_result of the last batch.
@@ -833,15 +865,77 @@ class ModelRunner:
         else:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
-            sampling_info.update_penalties()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)
+    def update_output_logprobs(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[List[int]],
+        next_token_ids: torch.Tensor,
+        *,
+        num_tokens_per_req: List[int],
+    ):
+        """Update logits_output's output logprob fields based on next_token_ids.
+        Args:
+            logits_output: The logits output from the model forward
+            sampling_info: Sampling info for logprob calculation
+            top_logprobs_nums: Number of top logprobs per request.
+            token_ids_logprobs: Token ids to compute logprobs for, per request.
+            next_token_ids: Next token ids.
+            num_tokens_per_req: The number of tokens per request.
+        """
+        self._preprocess_logits(logits_output, sampling_info)
+        # Repeat top_logprobs_nums and token_ids_logprobs to match num_tokens_per_req,
+        # since with chunked prefill one request yields multiple output tokens.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+        self.sampler(
+            logits_output,
+            sampling_info,
+            True,  # return_logprob
+            top_logprobs_nums_repeat_interleaved,
+            token_ids_logprobs_repeat_interleaved,
+            batch_next_token_ids=next_token_ids,
+        )
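Note: the expansion loops above are the heart of returning logprobs with chunked prefill: the settings are per request, but one chunked request yields several output tokens per pass. Concretely:

# Per-request -> per-token expansion, as done in update_output_logprobs.
top_logprobs_nums = [2, 0, 5]    # one entry per request
num_tokens_per_req = [3, 1, 2]   # tokens each request contributes this pass

expanded = []
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
    expanded.extend([num] * num_tokens)
print(expanded)                  # [2, 2, 2, 0, 5, 5]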
+    def sample(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        """Sample and compute logprobs and update logits_output.
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+        Returns:
+            A tensor of next token ids
+        """
+        # For duplex models with multiple output streams.
+        if isinstance(logits_output, tuple):
+            return torch.stack(
+                [self.sample(values, forward_batch) for values in logits_output],
+                axis=-1,
+            )
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+        # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
-            sampling_info,
+            forward_batch.sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
         )
         return next_token_ids
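Note: for the duplex branch, sample() recurses once per logits stream and stacks the per-stream token ids on the last dimension; a shape-level sketch:

# Shape-level sketch of the duplex (tuple) branch of sample().
import torch

bs = 4
ids_a = torch.randint(0, 32000, (bs,))  # stand-in for sample(stream_a, batch)
ids_b = torch.randint(0, 32000, (bs,))  # stand-in for sample(stream_b, batch)
next_token_ids = torch.stack([ids_a, ids_b], dim=-1)
print(next_token_ids.shape)             # torch.Size([4, 2])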