Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -13,9 +13,12 @@
|
||||
# ==============================================================================
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
|
||||
import collections
|
||||
import datetime
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
@@ -58,6 +61,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
@@ -73,10 +77,15 @@ from sglang.srt.utils import (
|
||||
set_cpu_offload_max_bytes,
|
||||
set_cuda_arch,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
|
||||
@@ -180,9 +189,13 @@ class ModelRunner:
|
||||
"enable_dp_attention": server_args.enable_dp_attention,
|
||||
"enable_ep_moe": server_args.enable_ep_moe,
|
||||
"device": server_args.device,
|
||||
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
||||
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
||||
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
||||
"disable_radix_cache": server_args.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
|
||||
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -199,6 +212,18 @@ class ModelRunner:
|
||||
self.sampler = Sampler()
|
||||
self.load_model()
|
||||
|
||||
# Handle the case where some of models don't finish loading.
|
||||
try:
|
||||
dist.monitored_barrier(
|
||||
group=get_tp_group().cpu_group,
|
||||
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
|
||||
wait_all_ranks=True,
|
||||
)
|
||||
except RuntimeError:
|
||||
raise ValueError(
|
||||
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
|
||||
) from None
|
||||
|
||||
# Apply torchao quantization
|
||||
torchao_applied = getattr(self.model, "torchao_applied", False)
|
||||
# In layered loading, torchao may have been applied
|
||||
@@ -625,6 +650,9 @@ class ModelRunner:
|
||||
4096,
|
||||
)
|
||||
|
||||
if SGLANG_CI_SMALL_KV_SIZE:
|
||||
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
if self.is_draft_worker:
|
||||
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
||||
@@ -655,6 +683,7 @@ class ModelRunner:
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and not self.server_args.disable_mla
|
||||
@@ -758,9 +787,13 @@ class ModelRunner:
|
||||
return
|
||||
|
||||
tic = time.time()
|
||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||
logger.info(
|
||||
f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
|
||||
logger.info(
|
||||
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
|
||||
def apply_torch_tp(self):
|
||||
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
|
||||
@@ -820,11 +853,10 @@ class ModelRunner:
|
||||
else:
|
||||
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
|
||||
|
||||
def sample(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
) -> torch.Tensor:
|
||||
def _preprocess_logits(
|
||||
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
|
||||
):
|
||||
# Apply logit bias
|
||||
sampling_info = forward_batch.sampling_info
|
||||
if sampling_info.sampling_info_done:
|
||||
# Overlap mode: the function update_regex_vocab_mask was executed
|
||||
# in process_batch_result of the last batch.
|
||||
@@ -833,15 +865,77 @@ class ModelRunner:
|
||||
else:
|
||||
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||
sampling_info.update_regex_vocab_mask()
|
||||
sampling_info.update_penalties()
|
||||
sampling_info.apply_logits_bias(logits_output.next_token_logits)
|
||||
|
||||
def update_output_logprobs(
|
||||
self,
|
||||
logits_output: LogitsProcessorOutput,
|
||||
sampling_info: SamplingBatchInfo,
|
||||
top_logprobs_nums: List[int],
|
||||
token_ids_logprobs: List[int],
|
||||
next_token_ids: torch.Tensor,
|
||||
*,
|
||||
num_tokens_per_req: List[int],
|
||||
):
|
||||
"""Update the logits_output's output logprob based on next_token_ids
|
||||
|
||||
Args:
|
||||
logits_output: The logits output from the model forward
|
||||
sampling_info: Sampling info for logprob calculation
|
||||
top_logprobs_nums: Number of logprobs per request.
|
||||
next_token_ids: Next token ids.
|
||||
num_tokens_per_req: The number of tokens per request.
|
||||
|
||||
Returns:
|
||||
A list of next_token_ids
|
||||
"""
|
||||
self._preprocess_logits(logits_output, sampling_info)
|
||||
# We should repeat top_logprobs_nums to match num_tokens_per_req.
|
||||
top_logprobs_nums_repeat_interleaved = []
|
||||
token_ids_logprobs_repeat_interleaved = []
|
||||
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
|
||||
top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
|
||||
for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
|
||||
token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
|
||||
self.sampler(
|
||||
logits_output,
|
||||
sampling_info,
|
||||
True,
|
||||
top_logprobs_nums_repeat_interleaved,
|
||||
token_ids_logprobs_repeat_interleaved,
|
||||
batch_next_token_ids=next_token_ids,
|
||||
)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits_output: LogitsProcessorOutput,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
"""Sample and compute logprobs and update logits_output.
|
||||
|
||||
Args:
|
||||
logits_output: The logits output from the model forward
|
||||
forward_batch: The forward batch that generates logits_output
|
||||
|
||||
Returns:
|
||||
A list of next_token_ids
|
||||
"""
|
||||
# For duplex models with multiple output streams.
|
||||
if isinstance(logits_output, tuple):
|
||||
return torch.stack(
|
||||
[self.sample(values, forward_batch) for values in logits_output],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
self._preprocess_logits(logits_output, forward_batch.sampling_info)
|
||||
|
||||
# Sample the next tokens
|
||||
next_token_ids = self.sampler(
|
||||
logits_output,
|
||||
sampling_info,
|
||||
forward_batch.sampling_info,
|
||||
forward_batch.return_logprob,
|
||||
forward_batch.top_logprobs_nums,
|
||||
forward_batch.token_ids_logprobs,
|
||||
)
|
||||
return next_token_ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user