From 9da5a60b18bcd0331a7b54e89d3d697db599f924 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 12 Oct 2024 17:53:23 -0700
Subject: [PATCH] Add an option to disable penalizer (#1651)

---
 python/sglang/srt/managers/schedule_batch.py   |   4 +-
 python/sglang/srt/managers/scheduler.py        |  14 +--
 .../sglang/srt/model_executor/model_runner.py  |   1 +
 .../srt/sampling/sampling_batch_info.py        |  80 ++++++++------
 python/sglang/srt/server_args.py               | 102 +++++++++---------
 5 files changed, 111 insertions(+), 90 deletions(-)
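
End to end, the new switch is an ordinary CLI flag. A minimal sketch of the round trip, assuming the usual ServerArgs.add_cli_args/from_cli_args pair (from_cli_args is visible at the end of this patch); the model path is a placeholder:

import argparse

from sglang.srt.server_args import ServerArgs

# Build a parser the same way the server entry point does. The new
# --disable-penalizer flag is registered along with the other options.
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)

# Placeholder model path; any served model behaves the same way.
args = parser.parse_args(
    ["--model-path", "meta-llama/Meta-Llama-3-8B", "--disable-penalizer"]
)
server_args = ServerArgs.from_cli_args(args)
assert server_args.disable_penalizer is True
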
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index a2f48131d..afe356b38 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -531,7 +531,9 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
 
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index a4ada01aa..27761c8f0 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -671,9 +671,10 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids = result
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
+            if batch.sampling_info.penalizer_orchestrator:
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )
 
             if logits_output:
                 # Move logprobs to cpu
@@ -755,9 +756,10 @@ class Scheduler:
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
+        if batch.sampling_info.penalizer_orchestrator:
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                next_token_ids
+            )
         self.num_generated_tokens += len(batch.reqs)
 
         # Move logprobs to cpu
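
cumulate_output_tokens feeds every newly sampled token back into the penalizer state, so it is the per-step cost that disappears when the orchestrator is disabled. The state it maintains drives logit adjustments of roughly the following shape (an illustrative sketch of standard frequency/presence penalties, not the actual penaltylib kernels):

import torch

def apply_frequency_presence_penalty(
    logits: torch.Tensor,         # [vocab_size] next-token logits
    output_counts: torch.Tensor,  # [vocab_size] tokens generated so far
    frequency_penalty: float,
    presence_penalty: float,
) -> torch.Tensor:
    # Penalize tokens proportionally to how often they were generated ...
    logits = logits - frequency_penalty * output_counts
    # ... and once more for merely having appeared at all.
    logits = logits - presence_penalty * (output_counts > 0).to(logits.dtype)
    return logits

logits = torch.zeros(8)
counts = torch.tensor([0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0])
print(apply_frequency_presence_penalty(logits, counts, 0.5, 0.3))
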
"auto" skip_tokenizer_init: bool = False load_format: str = "auto" - dtype: str = "auto" - device: str = "cuda" - kv_cache_dtype: str = "auto" trust_remote_code: bool = True - context_length: Optional[int] = None + dtype: str = "auto" + kv_cache_dtype: str = "auto" quantization: Optional[str] = None + context_length: Optional[int] = None + device: str = "cuda" served_model_name: Optional[str] = None chat_template: Optional[str] = None is_embedding: bool = False @@ -86,10 +86,15 @@ class ServerArgs: # Model override args in JSON json_model_override_args: str = "{}" - # Optimization/debug options + # LoRA + lora_paths: Optional[List[str]] = None + max_loras_per_batch: int = 8 + + # Kernel backend attention_backend: Optional[str] = None sampling_backend: Optional[str] = None + # Optimization/debug options disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False disable_radix_cache: bool = False @@ -99,6 +104,7 @@ class ServerArgs: disable_disk_cache: bool = False disable_custom_all_reduce: bool = False disable_mla: bool = False + disable_penalizer: bool = False enable_mixed_chunk: bool = False enable_torch_compile: bool = False max_torch_compile_bs: int = 32 @@ -106,10 +112,6 @@ class ServerArgs: enable_p2p_check: bool = False triton_attention_reduce_in_fp32: bool = False - # LoRA - lora_paths: Optional[List[str]] = None - max_loras_per_batch: int = 8 - def __post_init__(self): # Set missing default values if self.tokenizer_path is None: @@ -224,6 +226,11 @@ class ServerArgs: '"dummy" will initialize the weights with random values, ' "which is mainly for profiling.", ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", + ) parser.add_argument( "--dtype", type=str, @@ -238,13 +245,6 @@ class ServerArgs: '* "float" is shorthand for FP32 precision.\n' '* "float32" for FP32 precision.', ) - parser.add_argument( - "--device", - type=str, - default="cuda", - choices=["cuda", "xpu"], - help="The device type.", - ) parser.add_argument( "--kv-cache-dtype", type=str, @@ -252,17 +252,6 @@ class ServerArgs: choices=["auto", "fp8_e5m2"], help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.', ) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", - ) - parser.add_argument( - "--context-length", - type=int, - default=ServerArgs.context_length, - help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).", - ) parser.add_argument( "--quantization", type=str, @@ -278,6 +267,19 @@ class ServerArgs: ], help="The quantization method.", ) + parser.add_argument( + "--context-length", + type=int, + default=ServerArgs.context_length, + help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda", "xpu"], + help="The device type.", + ) parser.add_argument( "--served-model-name", type=str, @@ -440,7 +442,23 @@ class ServerArgs: default=ServerArgs.json_model_override_args, ) - # Optimization/debug options + # LoRA + parser.add_argument( + "--lora-paths", + type=str, + nargs="*", + default=None, + action=LoRAPathAction, + help="The list of LoRA adapters. 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index a07bc04be..2966bed64 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -35,12 +35,12 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
-    dtype: str = "auto"
-    device: str = "cuda"
-    kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
-    context_length: Optional[int] = None
+    dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
     quantization: Optional[str] = None
+    context_length: Optional[int] = None
+    device: str = "cuda"
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
@@ -86,10 +86,15 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Optimization/debug options
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
+
+    # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
 
+    # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -99,6 +104,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    disable_penalizer: bool = False
    enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
@@ -106,10 +112,6 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
-    # LoRA
-    lora_paths: Optional[List[str]] = None
-    max_loras_per_batch: int = 8
-
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:
@@ -224,6 +226,11 @@ class ServerArgs:
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling.",
         )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+        )
         parser.add_argument(
             "--dtype",
             type=str,
             default=ServerArgs.dtype,
             choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
             help="Data type for model weights and activations.\n\n"
             '* "auto" will use FP16 precision for FP32 and FP16 models, and '
             "BF16 precision for BF16 models.\n"
             '* "half" for FP16. Recommended for AWQ quantization.\n'
             '* "float16" is the same as "half".\n'
             '* "bfloat16" for a balance between precision and range.\n'
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
-        parser.add_argument(
-            "--device",
-            type=str,
-            default="cuda",
-            choices=["cuda", "xpu"],
-            help="The device type.",
-        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
             choices=["auto", "fp8_e5m2"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
         )
-        parser.add_argument(
-            "--trust-remote-code",
-            action="store_true",
-            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
-        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
-        )
         parser.add_argument(
             "--quantization",
             type=str,
             default=ServerArgs.quantization,
             choices=[
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda", "xpu"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--served-model-name",
             type=str,
@@ -440,7 +442,23 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        # Optimization/debug options
+        # LoRA
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either plain paths or renamed paths in the format {name}={path}.",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, including base-only requests.",
+        )
+
+        # Kernel backend
         parser.add_argument(
             "--attention-backend",
             type=str,
@@ -455,6 +473,8 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
@@ -501,6 +521,11 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
+        parser.add_argument(
+            "--disable-penalizer",
+            action="store_true",
+            help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -534,27 +559,6 @@ class ServerArgs:
             help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
             "This only affects Triton attention kernels.",
         )
-        parser.add_argument(
-            "--efficient-weight-load",
-            action="store_true",
-            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
-        )
-
-        # LoRA options
-        parser.add_argument(
-            "--lora-paths",
-            type=str,
-            nargs="*",
-            default=None,
-            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
-        )
-        parser.add_argument(
-            "--max-loras-per-batch",
-            type=int,
-            default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
-        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
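
Taken together, ServerArgs.disable_penalizer flows through ModelRunner's global_server_args_dict into ScheduleBatch and finally into SamplingBatchInfo.from_schedule_batch. The flag can also be set programmatically when embedding the server (a sketch; the model path is a placeholder):

from sglang.srt.server_args import ServerArgs

# disable_penalizer defaults to False; setting it skips building the
# BatchedPenalizerOrchestrator for every scheduled batch.
server_args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B",
    disable_penalizer=True,
)
print(server_args.disable_penalizer)  # True
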