feat: frequency, min_new_tokens, presence, and repetition penalties (#973)
This commit is contained in:
@@ -24,6 +24,7 @@ import numpy as np
|
||||
import torch
|
||||
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
@@ -222,8 +223,9 @@ class Req:
|
||||
)
|
||||
return
|
||||
|
||||
last_token_id = self.output_ids[-1]
|
||||
if (
|
||||
self.output_ids[-1] == self.tokenizer.eos_token_id
|
||||
last_token_id == self.tokenizer.eos_token_id
|
||||
and not self.sampling_params.ignore_eos
|
||||
):
|
||||
self.finished_reason = FINISH_MATCHED_TOKEN(
|
||||
@@ -231,6 +233,10 @@ class Req:
|
||||
)
|
||||
return
|
||||
|
||||
if last_token_id in self.sampling_params.stop_token_ids:
|
||||
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
||||
return
|
||||
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
tail_str = self.tokenizer.decode(
|
||||
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
|
||||
@@ -321,8 +327,7 @@ class ScheduleBatch:
|
||||
temperatures: torch.Tensor = None
|
||||
top_ps: torch.Tensor = None
|
||||
top_ks: torch.Tensor = None
|
||||
frequency_penalties: torch.Tensor = None
|
||||
presence_penalties: torch.Tensor = None
|
||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||
logit_bias: torch.Tensor = None
|
||||
|
||||
@classmethod
|
||||
@@ -386,15 +391,24 @@ class ScheduleBatch:
|
||||
self.top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
||||
)
|
||||
self.frequency_penalties = torch.tensor(
|
||||
[r.sampling_params.frequency_penalty for r in reqs],
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
)
|
||||
self.presence_penalties = torch.tensor(
|
||||
[r.sampling_params.presence_penalty for r in reqs],
|
||||
dtype=torch.float,
|
||||
|
||||
# Each penalizer does nothing if it evaluates itself as not required by looking at
|
||||
# the sampling_params of the requests (see {_is_required()} of each penalizer). So this
|
||||
# should not add hefty computation overhead other than simple checks.
|
||||
#
|
||||
# While we could choose not to even create the class instances if they are not required, this
|
||||
# could add additional complexity to the {ScheduleBatch} class, especially since we need to
|
||||
# handle {filter_batch()} and {merge()} cases as well.
|
||||
self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=self,
|
||||
device=device,
|
||||
Penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
penaltylib.BatchedPresencePenalizer,
|
||||
penaltylib.BatchedRepetitionPenalizer,
|
||||
},
|
||||
)
|
||||
|
||||
# Handle logit bias but only allocate when needed
|
||||
@@ -617,6 +631,9 @@ class ScheduleBatch:
|
||||
input_ids = [
|
||||
r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
|
||||
]
|
||||
else:
|
||||
self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
|
||||
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
@@ -648,12 +665,12 @@ class ScheduleBatch:
|
||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
|
||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"frequency_penalties",
|
||||
"presence_penalties",
|
||||
"logit_bias",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
@@ -674,12 +691,12 @@ class ScheduleBatch:
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"frequency_penalties",
|
||||
"presence_penalties",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
other_val = getattr(other, item, None)
|
||||
@@ -721,7 +738,8 @@ class ScheduleBatch:
|
||||
] = 1
|
||||
logits[i].masked_fill_(~allowed_mask, float("-inf"))
|
||||
|
||||
# TODO(lmzheng): apply penalty
|
||||
logits = self.penalizer_orchestrator.apply(logits)
|
||||
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
|
||||
if not global_server_args_dict["disable_flashinfer_sampling"]:
|
||||
@@ -754,6 +772,8 @@ class ScheduleBatch:
|
||||
req.regex_fsm_state, batch_next_token_ids_cpu[i]
|
||||
)
|
||||
|
||||
self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
|
||||
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user