Add an option to disable penalizer (#1651)
@@ -531,7 +531,9 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )

     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED

@@ -671,9 +671,10 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids = result
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
+            if batch.sampling_info.penalizer_orchestrator:
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )

             if logits_output:
                 # Move logprobs to cpu

@@ -755,9 +756,10 @@ class Scheduler:

     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
+        if batch.sampling_info.penalizer_orchestrator:
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                next_token_ids
+            )
         self.num_generated_tokens += len(batch.reqs)

         # Move logprobs to cpu

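Both result handlers now repeat the same guard, because penalizer_orchestrator is None when the server runs with --disable-penalizer. A hypothetical helper (not part of this commit) illustrates the pattern being applied in both prefill and decode:

# Hypothetical helper, not in this commit: centralizes the Optional-orchestrator guard.
def cumulate_output_tokens_if_enabled(sampling_info, next_token_ids):
    orchestrator = sampling_info.penalizer_orchestrator
    if orchestrator:  # None when the server was started with --disable-penalizer
        orchestrator.cumulate_output_tokens(next_token_ids)
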
@@ -119,6 +119,7 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
+                "disable_penalizer": server_args.disable_penalizer,
             }
         )

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import dataclasses
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

 import torch

@@ -33,15 +33,20 @@ class SamplingBatchInfo:
     regex_fsm_states: List[int] = None

     # Penalizer
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    linear_penalties: torch.Tensor = None
-    scaling_penalties: torch.Tensor = None
+    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
+    linear_penalties: Optional[torch.Tensor] = None
+    scaling_penalties: Optional[torch.Tensor] = None

     # Device
     device: str = "cuda"

     @classmethod
-    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+    def from_schedule_batch(
+        cls,
+        batch: ScheduleBatch,
+        vocab_size: int,
+        disable_penalizer: bool,
+    ):
         reqs = batch.reqs
         with batch.input_ids.device:
             temperatures = torch.tensor(

@@ -76,17 +81,20 @@ class SamplingBatchInfo:
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
         # handle {filter_batch()} and {merge()} cases as well.
-        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=batch,
-            device=batch.input_ids.device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
+        if disable_penalizer:
+            ret.penalizer_orchestrator = None
+        else:
+            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+                vocab_size=vocab_size,
+                batch=batch,
+                device=batch.input_ids.device,
+                Penalizers={
+                    penaltylib.BatchedFrequencyPenalizer,
+                    penaltylib.BatchedMinNewTokensPenalizer,
+                    penaltylib.BatchedPresencePenalizer,
+                    penaltylib.BatchedRepetitionPenalizer,
+                },
+            )

         # Handle logit bias but only allocate when needed
         ret.logit_bias = None

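For callers, the practical effect is that from_schedule_batch now takes the flag and penalizer_orchestrator may be None afterwards. A minimal sketch of the new calling convention; batch, vocab_size, and next_token_ids are placeholders, not values taken from this diff:

# Sketch only: `batch`, `vocab_size`, and `next_token_ids` are assumed to exist.
sampling_info = SamplingBatchInfo.from_schedule_batch(
    batch, vocab_size, disable_penalizer=True
)

# With the penalizer disabled, no orchestrator is created...
assert sampling_info.penalizer_orchestrator is None

# ...so every use must be guarded, mirroring the Scheduler changes above.
if sampling_info.penalizer_orchestrator:
    sampling_info.penalizer_orchestrator.cumulate_output_tokens(next_token_ids)
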
@@ -97,6 +105,9 @@ class SamplingBatchInfo:
         return len(self.temperatures)

     def update_penalties(self):
+        if not self.penalizer_orchestrator:
+            return
+
         self.scaling_penalties = None
         self.linear_penalties = None

@@ -117,26 +128,26 @@ class SamplingBatchInfo:

     def update_regex_vocab_mask(self):
         has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-
-        # Reset the vocab mask
-        self.vocab_mask = None
+        if not has_regex:
+            self.vocab_mask = None
+            return

-        if has_regex:
-            self.vocab_mask = torch.zeros(
-                len(self.temperatures),
-                self.vocab_size,
-                dtype=torch.bool,
-                device=self.device,
-            )
-            for i, regex_fsm in enumerate(self.regex_fsms):
-                if regex_fsm is not None:
-                    self.vocab_mask[i].fill_(1)
-                    self.vocab_mask[i][
-                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                    ] = 0
+        self.vocab_mask = torch.zeros(
+            len(self.temperatures),
+            self.vocab_size,
+            dtype=torch.bool,
+            device=self.device,
+        )
+        for i, regex_fsm in enumerate(self.regex_fsms):
+            if regex_fsm is not None:
+                self.vocab_mask[i].fill_(1)
+                self.vocab_mask[i][
+                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
+                ] = 0

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

         for item in [
             "temperatures",

@@ -175,7 +186,8 @@ class SamplingBatchInfo:
             return None

     def merge_batch(self, other: "SamplingBatchInfo"):
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

         for item in [
             "temperatures",

@@ -35,12 +35,12 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
-    dtype: str = "auto"
-    device: str = "cuda"
-    kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
-    context_length: Optional[int] = None
+    dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
     quantization: Optional[str] = None
+    context_length: Optional[int] = None
+    device: str = "cuda"
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False

@@ -86,10 +86,15 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"

-    # Optimization/debug options
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
+
+    # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None

+    # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False

@@ -99,6 +104,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    disable_penalizer: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32

@@ -106,10 +112,6 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False

-    # LoRA
-    lora_paths: Optional[List[str]] = None
-    max_loras_per_batch: int = 8
-
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:

@@ -224,6 +226,11 @@ class ServerArgs:
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling.",
         )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+        )
         parser.add_argument(
             "--dtype",
             type=str,

@@ -238,13 +245,6 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
-        parser.add_argument(
-            "--device",
-            type=str,
-            default="cuda",
-            choices=["cuda", "xpu"],
-            help="The device type.",
-        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,

@@ -252,17 +252,6 @@ class ServerArgs:
             choices=["auto", "fp8_e5m2"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
         )
-        parser.add_argument(
-            "--trust-remote-code",
-            action="store_true",
-            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
-        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
-        )
         parser.add_argument(
             "--quantization",
             type=str,

@@ -278,6 +267,19 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda", "xpu"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--served-model-name",
             type=str,

@@ -440,7 +442,23 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )

-        # Optimization/debug options
+        # LoRA
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
+        # Kernel backend
         parser.add_argument(
             "--attention-backend",
             type=str,

@@ -455,6 +473,8 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",

@@ -501,6 +521,11 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
+        parser.add_argument(
+            "--disable-penalizer",
+            action="store_true",
+            help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",

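With the new argument registered, the flag can be exercised through the existing CLI plumbing. A hedged sketch; the import path, add_cli_args, and the --model-path argument are assumed from the surrounding file and are not shown in this diff:

import argparse

from sglang.srt.server_args import ServerArgs

# Assumed usage: build the parser the same way the server entry point does.
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(["--model-path", "dummy-model", "--disable-penalizer"])

server_args = ServerArgs.from_cli_args(args)
assert server_args.disable_penalizer is True  # penalizer setup is skipped at runtime
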
@@ -534,27 +559,6 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
-        parser.add_argument(
-            "--efficient-weight-load",
-            action="store_true",
-            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
-        )
-
-        # LoRA options
-        parser.add_argument(
-            "--lora-paths",
-            type=str,
-            nargs="*",
-            default=None,
-            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
-        )
-        parser.add_argument(
-            "--max-loras-per-batch",
-            type=int,
-            default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):