Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
@@ -109,11 +109,15 @@ def set_torch_compile_config():
def get_batch_sizes_to_capture(model_runner: ModelRunner):
    server_args = model_runner.server_args
    capture_bs = server_args.cuda_graph_bs

    if capture_bs is None:
-        if server_args.disable_cuda_graph_padding:
-            capture_bs = list(range(1, 33)) + [64, 128]
-        else:
-            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        if server_args.speculative_algorithm is None:
+            if server_args.disable_cuda_graph_padding:
+                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+            else:
+                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        else:
+            capture_bs = list(range(1, 33))

    if is_hip_:
        capture_bs += [i * 8 for i in range(21, 33)]
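Aside (illustration, not part of the commit): when a speculative algorithm is configured, the branch above keeps a dense list of small batch sizes, because every replayed batch gets padded up to the nearest captured size. A minimal stand-alone sketch of that padding step, with hypothetical names rather than SGLang's internal helpers:

import bisect

def pick_padded_bs(raw_bs, capture_bs):
    """Return the smallest captured CUDA-graph batch size >= raw_bs, or None."""
    sizes = sorted(set(capture_bs))
    idx = bisect.bisect_left(sizes, raw_bs)
    return sizes[idx] if idx < len(sizes) else None

# Dense list used when speculative decoding is enabled: padding waste stays small.
assert pick_padded_bs(7, list(range(1, 33))) == 7
# Sparse list: a batch of 70 is padded up to the next captured size, 72.
assert pick_padded_bs(70, [1, 2, 4] + [i * 8 for i in range(1, 21)]) == 72
assert pick_padded_bs(200, list(range(1, 33))) is None  # falls back to eager mode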
@@ -130,6 +134,7 @@ get_batch_sizes_to_capture(model_runner: ModelRunner):
            )
        )
    )

    capture_bs = [
        bs
        for bs in capture_bs
@@ -385,9 +390,6 @@ class CudaGraphRunner:

            run_once()

        torch.cuda.synchronize()
        self.model_runner.tp_group.barrier()
-
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
@@ -401,12 +403,11 @@ class CudaGraphRunner:
            global_graph_memory_pool = graph.pool()
        return graph, out

-    def replay(self, forward_batch: ForwardBatch):
-        assert forward_batch.out_cache_loc is not None
+    def recapture_if_needed(self, forward_batch: ForwardBatch):
+        # If the capture_hidden_mode changes, we need to recapture the graph
        hidden_mode_from_spec_info = getattr(
            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
        )
-        # If the capture_hidden_mode changes, we need to recapture the graph
        if (
            forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
            and self.capture_hidden_mode != CaptureHiddenMode.FULL
@@ -420,6 +421,9 @@ class CudaGraphRunner:
            self.capture_hidden_mode = hidden_mode_from_spec_info
            self.capture()

+    def replay(self, forward_batch: ForwardBatch):
+        self.recapture_if_needed(forward_batch)
+
        raw_bs = forward_batch.batch_size
        raw_num_token = raw_bs * self.num_tokens_per_bs
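For illustration (not part of the diff): the refactor pulls the "recapture on changed hidden-state mode" check out of replay() into its own method. The general pattern, shown here as a minimal sketch with assumed names rather than SGLang's real CudaGraphRunner, is to compare the requested capture setting with the one the graph was captured under and recapture only on a change:

from enum import IntEnum, auto

class CaptureHiddenMode(IntEnum):
    NULL = auto()
    FULL = auto()
    LAST = auto()

class TinyGraphRunner:
    def __init__(self):
        self.capture_hidden_mode = CaptureHiddenMode.NULL
        self.num_captures = 0

    def capture(self):
        self.num_captures += 1  # stands in for the expensive torch.cuda.graph capture

    def recapture_if_needed(self, requested):
        if requested != self.capture_hidden_mode:
            self.capture_hidden_mode = requested
            self.capture()

    def replay(self, requested):
        self.recapture_if_needed(requested)
        # ... copy inputs into static buffers and replay the captured graph ...

runner = TinyGraphRunner()
runner.replay(CaptureHiddenMode.FULL)  # recaptures once
runner.replay(CaptureHiddenMode.FULL)  # no recapture on the second call
assert runner.num_captures == 1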
@@ -31,7 +31,7 @@ from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Union

import torch
import triton
@@ -46,7 +46,8 @@ if TYPE_CHECKING:
    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm


class ForwardMode(IntEnum):
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):

class CaptureHiddenMode(IntEnum):
    NULL = auto()
+    # Capture hidden states of all tokens.
    FULL = auto()
+    # Capture a hidden state of the last token.
    LAST = auto()

    def need_capture(self):
@@ -148,6 +151,7 @@ class ForwardBatch:
    # For logprob
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None

    # Position information
    positions: torch.Tensor = None
@@ -160,6 +164,7 @@ class ForwardBatch:
    extend_prefix_lens_cpu: Optional[List[int]] = None
    extend_seq_lens_cpu: Optional[List[int]] = None
    extend_logprob_start_lens_cpu: Optional[List[int]] = None
+    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

    # For multimodal
    image_inputs: Optional[List[ImageInputs]] = None
@@ -190,10 +195,13 @@ class ForwardBatch:
    can_run_dp_cuda_graph: bool = False

    # Speculative decoding
-    spec_info: SpecInfo = None
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    spec_algorithm: SpeculativeAlgorithm = None
    capture_hidden_mode: CaptureHiddenMode = None

+    # For padding
+    padded_static_len: int = -1  # -1 if not padded
+
    # For Qwen2-VL
    mrope_positions: torch.Tensor = None
@@ -203,8 +211,13 @@ class ForwardBatch:
        batch: ModelWorkerBatch,
        model_runner: ModelRunner,
    ):

        device = model_runner.device
+        extend_input_logprob_token_ids_gpu = None
+        if batch.extend_input_logprob_token_ids is not None:
+            extend_input_logprob_token_ids_gpu = (
+                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
+            )

        ret = cls(
            forward_mode=batch.forward_mode,
            batch_size=len(batch.seq_lens),
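Side note (not part of the diff): the extend_input_logprob_token_ids tensor is moved to the GPU with non_blocking=True so the copy can overlap with other host work. The copy is only truly asynchronous when the source CPU tensor lives in pinned memory; a small self-contained sketch of the same pattern:

import torch

token_ids_cpu = torch.randint(0, 32000, (4096,), dtype=torch.int64)
if torch.cuda.is_available():
    token_ids_cpu = token_ids_cpu.pin_memory()  # pinned memory enables a real async copy
    token_ids_gpu = token_ids_cpu.to("cuda", non_blocking=True)
    # ... enqueue kernels that do not read token_ids_gpu yet ...
    torch.cuda.current_stream().synchronize()  # make sure the copy finished before use
    assert torch.equal(token_ids_gpu.cpu(), token_ids_cpu)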
@@ -220,6 +233,7 @@ class ForwardBatch:
            seq_lens_sum=batch.seq_lens_sum,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
+            token_ids_logprobs=batch.token_ids_logprobs,
            global_num_tokens=batch.global_num_tokens,
            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
            lora_paths=batch.lora_paths,
@@ -231,6 +245,7 @@ class ForwardBatch:
            spec_info=batch.spec_info,
            capture_hidden_mode=batch.capture_hidden_mode,
            input_embeds=batch.input_embeds,
+            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
        )

        if ret.global_num_tokens is not None:
@@ -341,6 +356,7 @@ class ForwardBatch:
            )
            batch.image_inputs[i].mrope_position_delta = mrope_position_delta
            mrope_positions_list[i] = mrope_positions

        self.mrope_positions = torch.concat(
            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
            axis=1,
@@ -379,7 +395,7 @@ def compute_position_kernel(
    extend_seq_lens,
):
    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)

    prefix_len = tl.load(extend_prefix_lens + pid)
    seq_len = tl.load(extend_seq_lens + pid)
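For context (not part of the diff): casting the program id to int64 guards the downstream offset arithmetic (pid * stride + ...) against int32 overflow once the flattened index range exceeds 2**31. A self-contained Triton sketch of the same pattern, using a hypothetical kernel rather than the SGLang one:

import torch
import triton
import triton.language as tl

@triton.jit
def copy_rows_kernel(src_ptr, dst_ptr, row_stride, BLOCK_SIZE: tl.constexpr):
    # Promote to int64 before computing offsets so pid * row_stride cannot wrap.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * row_stride + tl.arange(0, BLOCK_SIZE)
    # Each program copies one full row of BLOCK_SIZE elements, so no mask is needed.
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets))

if torch.cuda.is_available():
    x = torch.randn(1024, 512, device="cuda")
    y = torch.empty_like(x)
    copy_rows_kernel[(x.shape[0],)](x, y, x.stride(0), BLOCK_SIZE=512)
    assert torch.equal(x, y)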
@@ -13,9 +13,12 @@
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

+import collections
+import datetime
import gc
import json
import logging
+import os
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@@ -58,6 +61,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -73,10 +77,15 @@ from sglang.srt.utils import (
    set_cpu_offload_max_bytes,
    set_cuda_arch,
)
+from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

+
+SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+

class ModelRunner:
    """ModelRunner runs the forward passes of the models."""
@@ -180,9 +189,13 @@ class ModelRunner:
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
                "device": server_args.device,
                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                "disable_radix_cache": server_args.disable_radix_cache,
                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
            }
        )
@@ -199,6 +212,18 @@ class ModelRunner:
        self.sampler = Sampler()
        self.load_model()

+        # Handle the case where some ranks do not finish loading the model.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} finished loading the model, but some other ranks did not. This is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied
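As a note (not part of the diff): dist.monitored_barrier is gloo-backed (hence the cpu_group) and turns a silent hang into an actionable error when one rank dies or stalls during weight loading. A stand-alone sketch of the same call pattern, with an assumed timeout value:

import datetime
import torch.distributed as dist

LOADING_TIMEOUT_S = 300  # assumed value for this sketch

def wait_for_all_ranks():
    try:
        # monitored_barrier is only implemented for the gloo backend.
        dist.monitored_barrier(
            timeout=datetime.timedelta(seconds=LOADING_TIMEOUT_S),
            wait_all_ranks=True,
        )
    except RuntimeError as e:
        raise ValueError(f"Some ranks did not finish loading in time: {e}") from None

if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=2 this_script.py
    dist.init_process_group(backend="gloo")
    wait_for_all_ranks()
    print(f"rank {dist.get_rank()} passed the barrier")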
@@ -625,6 +650,9 @@ class ModelRunner:
                4096,
            )

+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
+
        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
@@ -655,6 +683,7 @@ class ModelRunner:
            device=self.device,
            enable_memory_saver=self.server_args.enable_memory_saver,
        )

        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
@@ -758,9 +787,13 @@ class ModelRunner:
            return

        tic = time.time()
-        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        logger.info(
+            f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
        self.cuda_graph_runner = CudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        logger.info(
+            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -820,11 +853,10 @@ class ModelRunner:
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

-    def sample(
-        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
-    ) -> torch.Tensor:
+    def _preprocess_logits(
+        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
+    ):
        # Apply logit bias
-        sampling_info = forward_batch.sampling_info
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
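For context (not part of the diff): _preprocess_logits supports both scheduling modes. In overlap mode the CPU-heavy sampling-state updates (vocab masks and, with this commit, penalties) are prepared by another thread while the GPU runs the forward pass, and an event signals completion; in normal mode the same work happens inline. A simplified sketch of that handshake with assumed names, not the real SamplingBatchInfo:

import threading

class SamplingState:
    def __init__(self):
        self.sampling_info_done = None  # threading.Event in overlap mode, None otherwise
        self.penalties_ready = False

    def update_penalties(self):
        self.penalties_ready = True  # stands in for the real CPU-heavy update

def preprocess(state):
    if state.sampling_info_done is not None:
        state.sampling_info_done.wait()  # overlap mode: the work was done elsewhere
    else:
        state.update_penalties()         # normal mode: do it inline

# Overlap mode: a background thread prepares the state while the GPU is busy.
state = SamplingState()
state.sampling_info_done = threading.Event()
worker = threading.Thread(
    target=lambda: (state.update_penalties(), state.sampling_info_done.set())
)
worker.start()
preprocess(state)  # blocks only until the background update has finished
worker.join()
assert state.penalties_ready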
@@ -833,15 +865,77 @@ class ModelRunner:
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
+            sampling_info.update_penalties()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

+    def update_output_logprobs(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[int],
+        next_token_ids: torch.Tensor,
+        *,
+        num_tokens_per_req: List[int],
+    ):
+        """Update the logits_output's output logprob based on next_token_ids
+
+        Args:
+            logits_output: The logits output from the model forward
+            sampling_info: Sampling info for logprob calculation
+            top_logprobs_nums: Number of logprobs per request.
+            next_token_ids: Next token ids.
+            num_tokens_per_req: The number of tokens per request.
+
+        Returns:
+            A list of next_token_ids
+        """
+        self._preprocess_logits(logits_output, sampling_info)
+        # We should repeat top_logprobs_nums to match num_tokens_per_req.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+        self.sampler(
+            logits_output,
+            sampling_info,
+            True,
+            top_logprobs_nums_repeat_interleaved,
+            token_ids_logprobs_repeat_interleaved,
+            batch_next_token_ids=next_token_ids,
+        )
+
+    def sample(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        """Sample and compute logprobs and update logits_output.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+
+        Returns:
+            A list of next_token_ids
+        """
+        # For duplex models with multiple output streams.
+        if isinstance(logits_output, tuple):
+            return torch.stack(
+                [self.sample(values, forward_batch) for values in logits_output],
+                axis=-1,
+            )
+
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
-            sampling_info,
+            forward_batch.sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
        )
        return next_token_ids
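One last illustration (not part of the diff): update_output_logprobs is what lets logprobs be returned with chunked prefill. Per-request settings such as top_logprobs_nums apply to every prefill token a request contributes in the current chunk, so they are repeated once per token before the sampler computes logprobs. A stand-alone sketch of that expansion:

# Per-request inputs: one entry per request in the chunk.
top_logprobs_nums = [5, 0, 2]
token_ids_logprobs = [[42], [], [7, 9]]
num_tokens_per_req = [3, 1, 2]  # prefill tokens contributed by each request

# Expand to per-token lists, mirroring the repeat-interleave loops above.
top_per_token, ids_per_token = [], []
for num, n_tok in zip(top_logprobs_nums, num_tokens_per_req):
    top_per_token.extend([num] * n_tok)
for ids, n_tok in zip(token_ids_logprobs, num_tokens_per_req):
    ids_per_token.extend([ids] * n_tok)

assert top_per_token == [5, 5, 5, 0, 2, 2]
assert ids_per_token == [[42], [42], [42], [], [7, 9], [7, 9]]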