Make constrained decoding work for overlap scheduler (#2095)
@@ -1,5 +1,4 @@
 import logging
-import os
 from typing import Union

 import torch

@@ -136,6 +136,7 @@ class ImageInputs:
     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
@@ -187,11 +188,10 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
-
         self.sampling_params = sampling_params
         self.lora_path = lora_path

-        # Memory info
+        # Memory pool info
         self.req_pool_idx = None

         # Check finish
@@ -428,7 +428,7 @@ bid = 0

 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all information of a batch."""
+    """Store all information of a batch on the scheduler."""

     # Request, memory pool, and cache
     reqs: List[Req]
@@ -438,9 +438,9 @@ class ScheduleBatch:

     # For utility
     model_config: ModelConfig = None
-
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
+    next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
     input_ids: torch.Tensor = None
@@ -509,7 +509,7 @@ class ScheduleBatch:
     def is_empty(self):
         return len(self.reqs) == 0

-    def alloc_req_slots(self, num_reqs):
+    def alloc_req_slots(self, num_reqs: int):
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
@@ -610,7 +610,7 @@ class ScheduleBatch:

         assert len(self.out_cache_loc) == self.extend_num_tokens

-    def prepare_for_extend(self):
+    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
         self.forward_mode = ForwardMode.EXTEND

         bs = len(self.reqs)
@@ -704,7 +704,7 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            global_server_args_dict["disable_penalizer"],
+            enable_overlap_schedule=enable_overlap_schedule,
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -746,6 +746,7 @@ class ScheduleBatch:
         return False

     def retract_decode(self):
+        """Retract the decoding requests when there is not enough memory."""
         sorted_indices = [i for i in range(len(self.reqs))]

         # TODO(lsyin): improve retraction policy for radix cache
@@ -886,18 +887,10 @@ class ScheduleBatch:

     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
-        self.input_ids = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
+        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
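A note on the `prepare_for_idle` hunk above: `torch.empty(0, dtype=..., device=...)` allocates the placeholder tensors directly on the target device, whereas the old `.to(self.device, non_blocking=True)` form first creates a CPU tensor and then schedules a copy. A minimal standalone comparison (illustrative, not from the PR):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Old pattern: allocate on CPU first, then move to the device.
a = torch.empty(0, dtype=torch.int32).to(device, non_blocking=True)

# New pattern: allocate directly on the device; no intermediate CPU tensor.
b = torch.empty(0, dtype=torch.int32, device=device)

assert a.device == b.device and a.dtype == b.dtype
```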
@@ -1063,7 +1056,6 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
-            sampling_info=self.sampling_info,
         )

     def __str__(self):
@@ -15,6 +15,7 @@ limitations under the License.

 """A scheduler that manages a tensor parallel GPU worker."""

+import dataclasses
 import logging
 import os
 import threading
@@ -63,6 +64,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
@@ -220,8 +222,12 @@ class Scheduler:

         # Init running status
         self.waiting_queue: List[Req] = []
         # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
+        self.cur_batch: Optional[ScheduleBatch] = None
+        # The last forward batch
+        self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
@@ -336,15 +342,12 @@ class Scheduler:

     @torch.no_grad()
     def event_loop_normal(self):
-        """A normal blocking scheduler loop."""
-        self.last_batch = None
-
+        """A normal scheduler loop."""
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()

             if self.server_args.enable_dp_attention:
                 batch = self.prepare_dp_attn_batch(batch)
@@ -353,20 +356,8 @@ class Scheduler:
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
-
-                # Decode multiple steps to reduce the overhead
-                if batch.forward_mode.is_decode():
-                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
-                        if not self.running_batch:
-                            break
-                        self.update_running_batch()
-                        if not self.running_batch:
-                            break
-                        if self.server_args.enable_dp_attention:
-                            batch = self.prepare_dp_attn_batch(batch)
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
             else:
                 # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
@@ -377,9 +368,6 @@ class Scheduler:
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

-        self.last_batch = None
-        self.running_batch = None
-
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -390,10 +378,24 @@ class Scheduler:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))

+                if self.last_batch is None:
+                    # A dummy first batch to start the pipeline for the overlap scheduler.
+                    # It is now used for triggering the sampling_info_done event.
+                    tmp_batch = ScheduleBatch(
+                        reqs=None,
+                        forward_mode=ForwardMode.DUMMY_FIRST,
+                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                    )
+                    self.process_batch_result(tmp_batch, None)
+
             if self.last_batch:
                 tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
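The block above is the heart of this change. A rough, self-contained sketch of the handshake it builds (stand-in classes, not sglang's real APIs): while the GPU runs batch N, the scheduler's CPU thread fills in the grammar vocab mask for batch N+1 and then sets `sampling_info_done`; the sampler waits on that event before reading the mask. The DUMMY_FIRST batch exists only so the very first batch also gets its event set.

```python
import threading
from collections import deque

class SamplingInfo:
    """Stand-in for SamplingBatchInfo: a grammar vocab mask plus a done event."""
    def __init__(self, name: str):
        self.name = name
        self.vocab_mask = None
        self.sampling_info_done = threading.Event()

def process_batch_result(next_info: SamplingInfo) -> None:
    # CPU-side work for the *next* batch, overlapped with GPU compute of the
    # current one; mirrors update_regex_vocab_mask() followed by event.set().
    next_info.vocab_mask = f"mask({next_info.name})"
    next_info.sampling_info_done.set()

def sample(info: SamplingInfo) -> str:
    # The sampler blocks until the scheduler has finished this batch's mask.
    info.sampling_info_done.wait()
    return f"sampled {info.name} with {info.vocab_mask}"

batches = deque(SamplingInfo(f"batch{i}") for i in range(3))
process_batch_result(batches[0])  # the DUMMY_FIRST step unblocks the first batch
for i, info in enumerate(batches):
    if i + 1 < len(batches):
        process_batch_result(batches[i + 1])  # prepare batch i+1 "in parallel"
    print(sample(info))
```

Here everything runs on one thread, so each event is already set before `sample` waits; in the real scheduler the wait is exactly what orders the CPU and GPU worker threads.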
@@ -806,7 +808,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
         )
-        new_batch.prepare_for_extend()
+        new_batch.prepare_for_extend(self.enable_overlap)

         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
@@ -893,14 +895,15 @@ class Scheduler:
         return ret

     def process_batch_result(self, batch: ScheduleBatch, result):
-        if batch.forward_mode.is_idle():
-            return
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
-        else:
+        elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_dummy_first():
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -953,6 +956,10 @@ class Scheduler:
                 else:
                     req.is_being_chunked -= 1

+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                batch.next_batch_sampling_info.sampling_info_done.set()
+
         else:  # embedding or reward model
             embeddings, bid = result
             embeddings = embeddings.tolist()
@@ -1022,6 +1029,10 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
         self.stream_output(batch.reqs)

         self.token_to_kv_pool.free_group_end()
@@ -18,7 +18,6 @@ limitations under the License.
 import dataclasses
 import logging
 import threading
-import time
 from queue import Queue
 from typing import Optional
@@ -96,9 +95,7 @@ class TpModelWorkerClient:
     @torch.no_grad()
     def forward_thread_func_(self):
         while True:
-            model_worker_batch, future_token_ids_ct, compute_info_done = (
-                self.input_queue.get()
-            )
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
             self.launch_done = threading.Event()
@@ -109,7 +106,6 @@ class TpModelWorkerClient:
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

             # Run forward
-            compute_info_done.wait()
             logits_output, next_token_ids = self.worker.forward_batch_generation(
                 model_worker_batch, self.launch_done
             )
@@ -160,15 +156,16 @@ class TpModelWorkerClient:
         return logits_output, next_token_ids

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+        # A cuda stream sync here to avoid the cuda illegal memory access error.
+        _ = model_worker_batch.seq_lens[0].item()
+
         # Push a new batch to the queue
         model_worker_batch.sampling_info = dataclasses.replace(
-            model_worker_batch.sampling_info
-        )
-        compute_info_done = torch.cuda.Event()
-        compute_info_done.record()
-        self.input_queue.put(
-            (model_worker_batch, self.future_token_ids_ct, compute_info_done)
+            model_worker_batch.sampling_info,
+            sampling_info_done=threading.Event(),
         )
+        self.cur_sampling_info = model_worker_batch.sampling_info
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))

         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
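`dataclasses.replace` above hands the worker thread its own shallow copy of the sampling info, with a fresh per-batch `sampling_info_done` event, so later scheduler-side mutation cannot race with the queued snapshot. A tiny standalone illustration under hypothetical field names:

```python
import dataclasses
import threading
from typing import Optional

@dataclasses.dataclass
class SamplingInfo:
    temperature: float
    sampling_info_done: Optional[threading.Event] = None

info = SamplingInfo(temperature=0.7)
# Shallow copy with one field overridden; the original is left untouched.
snapshot = dataclasses.replace(info, sampling_info_done=threading.Event())

info.temperature = 1.0                 # later mutation of the original...
assert snapshot.temperature == 0.7     # ...does not affect the queued copy
```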
@@ -52,15 +52,19 @@ if TYPE_CHECKING:
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
-    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+    # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both EXTEND and DECODE.
+    # Contains both EXTEND and DECODE when doing chunked prefill.
     MIXED = auto()
-    # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence is allocated.
+    # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
     IDLE = auto()

+    # A dummy first batch to start the pipeline for the overlap scheduler.
+    # It is now used for triggering the sampling_info_done event for the first prefill batch.
+    DUMMY_FIRST = auto()
+
     def is_prefill(self):
         return self == ForwardMode.PREFILL
@@ -76,6 +80,9 @@ class ForwardMode(IntEnum):
     def is_idle(self):
         return self == ForwardMode.IDLE

+    def is_dummy_first(self):
+        return self == ForwardMode.DUMMY_FIRST
+

 @dataclass
 class ForwardBatch:
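For reference, the `is_*` predicates follow the usual IntEnum helper pattern, letting callers dispatch on the mode without comparing enum members inline everywhere. A cut-down mirror (not the real class):

```python
from enum import IntEnum, auto

class Mode(IntEnum):
    EXTEND = auto()
    DECODE = auto()
    DUMMY_FIRST = auto()

    def is_dummy_first(self) -> bool:
        return self == Mode.DUMMY_FIRST

assert Mode.DUMMY_FIRST.is_dummy_first() and not Mode.DECODE.is_dummy_first()
```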
@@ -142,7 +142,6 @@ class ModelRunner:
             "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
             "disable_mla": server_args.disable_mla,
             "torchao_config": server_args.torchao_config,
-            "disable_penalizer": server_args.disable_penalizer,
             "enable_nan_detection": server_args.enable_nan_detection,
             "enable_dp_attention": server_args.enable_dp_attention,
         }
@@ -636,10 +635,18 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         sampling_info = forward_batch.sampling_info
-        sampling_info.update_regex_vocab_mask()
-        sampling_info.update_penalties()
+
+        if sampling_info.sampling_info_done:
+            # Overlap mode: the function update_regex_vocab_mask was executed
+            # in process_batch_result of the last batch.
+            if sampling_info.grammars:
+                sampling_info.sampling_info_done.wait()
+            sampling_info.update_penalties()
+        else:
+            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+            sampling_info.update_regex_vocab_mask()
+            sampling_info.update_penalties()
         logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

         # Sample the next tokens.
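What `update_regex_vocab_mask()` ultimately produces is a boolean mask over the vocabulary that the grammar uses to forbid tokens before sampling. A toy version of that masking step (illustrative only, not sglang's actual kernel):

```python
import torch

logits = torch.randn(1, 8)  # (batch, vocab_size)
allowed = torch.tensor([[True, False, True, True, False, True, False, False]])

# Grammar-disallowed tokens get -inf, so softmax assigns them probability 0.
masked = logits.masked_fill(~allowed, float("-inf"))
probs = torch.softmax(masked, dim=-1)
assert torch.all(probs[~allowed] == 0)
```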
@@ -1,12 +1,17 @@
 from __future__ import annotations

 import dataclasses
+import logging
+import threading
 from typing import TYPE_CHECKING, Callable, List, Optional

 import torch

 import sglang.srt.sampling.penaltylib as penaltylib

+logger = logging.getLogger(__name__)
+
+
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -28,6 +33,7 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
     grammars: Optional[List] = None
+    sampling_info_done: Optional[threading.Event] = None
     logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -42,10 +48,7 @@ class SamplingBatchInfo:

     @classmethod
     def from_schedule_batch(
-        cls,
-        batch: ScheduleBatch,
-        vocab_size: int,
-        disable_penalizer: bool,
+        cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
     ):
         reqs = batch.reqs
         device = batch.device
@@ -79,6 +82,33 @@ class SamplingBatchInfo:
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

+        if enable_overlap_schedule:
+            # TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
+            # so it is kind of tricky to make them work with the overlap scheduler.
+            # It requires correctly updating the penalty logits before the sampling and syncing the events.
+            # We will support them later.
+            penalizers = {
+                penaltylib.BatchedMinNewTokensPenalizer,
+            }
+            if (
+                any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
+                or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
+                or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
+            ):
+                logger.warning(
+                    "frequency_penalty, presence_penalty, and repetition_penalty are not supported "
+                    "when using the default overlap scheduler. They will be ignored. "
+                    "Please add `--disable-overlap` when launching the server if you need these features. "
+                    "The speed will be slower in that case."
+                )
+        else:
+            penalizers = {
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            }
+
         # Each penalizer will do nothing if it evaluates itself as not required by looking at
         # the sampling_params of the requests (see {_is_required()} of each penalizer). So this
         # should not add hefty computation overhead other than simple checks.
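The TODO above is the crux: frequency, presence, and repetition penalties are functions of tokens the model has already emitted, which the overlap scheduler has not materialized on the CPU yet when it prepares the next batch. For reference, a plain (non-overlapped) frequency penalty is roughly this (illustrative only):

```python
import torch

vocab_size = 8
logits = torch.zeros(vocab_size)
output_ids = [3, 3, 5]  # tokens already generated -- not yet known under overlap

# logits[token] -= frequency_penalty * count(token in generated output)
counts = torch.bincount(torch.tensor(output_ids), minlength=vocab_size)
frequency_penalty = 0.5
logits -= frequency_penalty * counts.float()
assert logits[3] == -1.0 and logits[5] == -0.5
```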
@@ -86,20 +116,12 @@ class SamplingBatchInfo:
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially since we need to
         # handle the {filter_batch()} and {merge_batch()} cases as well.
-        if disable_penalizer:
-            ret.penalizer_orchestrator = None
-        else:
-            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-                vocab_size=vocab_size,
-                batch=batch,
-                device=batch.device,
-                Penalizers={
-                    penaltylib.BatchedFrequencyPenalizer,
-                    penaltylib.BatchedMinNewTokensPenalizer,
-                    penaltylib.BatchedPresencePenalizer,
-                    penaltylib.BatchedRepetitionPenalizer,
-                },
-            )
+        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=batch,
+            device=batch.device,
+            Penalizers=penalizers,
+        )

         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
@@ -133,13 +155,13 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)

     def update_regex_vocab_mask(self):
-        if not self.grammars or not any(grammar for grammar in self.grammars):
+        if not self.grammars:
             self.vocab_mask = None
             self.apply_mask = None
             return

         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar is not None)
+        grammar = next(grammar for grammar in self.grammars if grammar)

         # maybe we can reuse the existing mask?
         self.vocab_mask = grammar.allocate_vocab_mask(
@@ -123,7 +123,6 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
-    disable_penalizer: bool = False
     enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -200,12 +199,7 @@ class ServerArgs:
         )

         if self.enable_overlap_schedule:
-            logger.warning(
-                "Overlap scheduler mode is enabled. This is an experimental feature. "
-                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                "and embedding APIs are not supported and will lead to wrong results. "
-            )
-            self.disable_penalizer = True
             self.disable_jump_forward = True

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -622,11 +616,6 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
-        parser.add_argument(
-            "--disable-penalizer",
-            action="store_true",
-            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
-        )
         parser.add_argument(
             "--disable-nan-detection",
             action="store_true",