Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -13,9 +13,12 @@
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""
import collections
import datetime
import gc
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@@ -58,6 +61,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -73,10 +77,15 @@ from sglang.srt.utils import (
set_cpu_offload_max_bytes,
set_cuda_arch,
)
from sglang.utils import get_exception_traceback
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Optional environment override, kept as the raw string (None when unset).
# When set, ModelRunner replaces `max_total_num_tokens` with its integer value.
# NOTE(review): the name suggests this is meant for CI runs only - confirm.
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
# Timeout, in seconds, for the `dist.monitored_barrier` that ModelRunner runs
# right after `load_model()` to detect ranks that did not finish loading.
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
class ModelRunner:
"""ModelRunner runs the forward passes of the models."""
@@ -180,9 +189,13 @@ class ModelRunner:
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
}
)
@@ -199,6 +212,18 @@ class ModelRunner:
self.sampler = Sampler()
self.load_model()
# Handle the case where some ranks do not finish loading the model.
try:
dist.monitored_barrier(
group=get_tp_group().cpu_group,
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
wait_all_ranks=True,
)
except RuntimeError:
raise ValueError(
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
) from None
# Apply torchao quantization
torchao_applied = getattr(self.model, "torchao_applied", False)
# In layered loading, torchao may have been applied
@@ -625,6 +650,9 @@ class ModelRunner:
4096,
)
if SGLANG_CI_SMALL_KV_SIZE:
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
if not self.spec_algorithm.is_none():
if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
@@ -655,6 +683,7 @@ class ModelRunner:
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
@@ -758,9 +787,13 @@ class ModelRunner:
return
tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.")
logger.info(
f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
self.cuda_graph_runner = CudaGraphRunner(self)
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
logger.info(
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
def apply_torch_tp(self):
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -820,11 +853,10 @@ class ModelRunner:
else:
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
def sample(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
) -> torch.Tensor:
def _preprocess_logits(
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
):
# Apply logit bias
sampling_info = forward_batch.sampling_info
if sampling_info.sampling_info_done:
# Overlap mode: the function update_regex_vocab_mask was executed
# in process_batch_result of the last batch.
@@ -833,15 +865,77 @@ class ModelRunner:
else:
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info.update_regex_vocab_mask()
sampling_info.update_penalties()
sampling_info.apply_logits_bias(logits_output.next_token_logits)
def update_output_logprobs(
    self,
    logits_output: LogitsProcessorOutput,
    sampling_info: SamplingBatchInfo,
    top_logprobs_nums: List[int],
    token_ids_logprobs: List[int],
    next_token_ids: torch.Tensor,
    *,
    num_tokens_per_req: List[int],
) -> None:
    """Update the logits_output's output logprob based on next_token_ids.

    The logits are preprocessed first, then the sampler is invoked with the
    given ``next_token_ids`` passed as ``batch_next_token_ids``.

    Args:
        logits_output: The logits output from the model forward.
        sampling_info: Sampling info for logprob calculation.
        top_logprobs_nums: Number of logprobs per request.
        token_ids_logprobs: One entry per request; each entry is repeated
            as a unit for every token of that request.
            NOTE(review): annotated as List[int], but the entries look like
            per-request token id collections - confirm the element type.
        next_token_ids: Next token ids.
        num_tokens_per_req: The number of tokens per request.

    Returns:
        None. The sampler's return value is discarded.
    """
    self._preprocess_logits(logits_output, sampling_info)

    # The per-request settings must be repeated to match num_tokens_per_req.
    # zip() stops at the shorter input, so surplus entries on either side
    # are ignored.
    top_logprobs_nums_repeat_interleaved = [
        num
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req)
        for _ in range(num_tokens)
    ]
    token_ids_logprobs_repeat_interleaved = [
        token_ids
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req)
        for _ in range(num_tokens)
    ]

    self.sampler(
        logits_output,
        sampling_info,
        # Same positional slot that `sample` fills with
        # forward_batch.return_logprob.
        True,
        top_logprobs_nums_repeat_interleaved,
        token_ids_logprobs_repeat_interleaved,
        batch_next_token_ids=next_token_ids,
    )
def sample(
    self,
    logits_output: LogitsProcessorOutput,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    """Sample and compute logprobs and update logits_output.

    Args:
        logits_output: The logits output from the model forward, or a tuple
            of such outputs (one per output stream).
        forward_batch: The forward batch that generates logits_output.

    Returns:
        The value returned by the sampler for the next token ids. For a
        tuple input, the per-stream results are stacked along a new last
        axis.
    """
    # For duplex models with multiple output streams.
    if isinstance(logits_output, tuple):
        return torch.stack(
            [self.sample(values, forward_batch) for values in logits_output],
            axis=-1,
        )

    self._preprocess_logits(logits_output, forward_batch.sampling_info)

    # Sample the next tokens.
    # The sampling info is read from forward_batch; there is no local
    # `sampling_info` name in this method, so it must not be passed here.
    next_token_ids = self.sampler(
        logits_output,
        forward_batch.sampling_info,
        forward_batch.return_logprob,
        forward_batch.top_logprobs_nums,
        forward_batch.token_ids_logprobs,
    )
    return next_token_ids