feat: frequency, min_new_tokens, presence, and repetition penalties (#973)

Author: Juwan Yoo
Date: 2024-08-08 04:21:08 -07:00
Committed by: GitHub
Parent: 228cf47547
Commit: ab7875941b

20 changed files with 1898 additions and 18 deletions


@@ -24,6 +24,7 @@ import numpy as np
 import torch
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
+import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
@@ -222,8 +223,9 @@ class Req:
             )
             return

+        last_token_id = self.output_ids[-1]
         if (
-            self.output_ids[-1] == self.tokenizer.eos_token_id
+            last_token_id == self.tokenizer.eos_token_id
             and not self.sampling_params.ignore_eos
         ):
             self.finished_reason = FINISH_MATCHED_TOKEN(
@@ -231,6 +233,10 @@
             )
             return

+        if last_token_id in self.sampling_params.stop_token_ids:
+            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
+            return
+
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
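
The new stop_token_ids check slots between the EOS test and the stop-string scan: matching a stop token is a pure integer comparison, while stop strings still pay for a tokenizer.decode of the output tail. A minimal sketch of the resulting termination order, as a standalone helper (the function name and signature are illustrative, not part of the diff):

def check_finished_sketch(output_ids, params, tokenizer):
    """Illustrative ordering of the termination checks after this change."""
    last_token_id = output_ids[-1]
    # 1. EOS token, unless the request asked to ignore it.
    if last_token_id == tokenizer.eos_token_id and not params.ignore_eos:
        return ("eos", last_token_id)
    # 2. User-supplied stop token ids: integer comparison, no decoding needed.
    if last_token_id in params.stop_token_ids:
        return ("stop_token", last_token_id)
    # 3. Stop strings: only now do we pay for decoding the tail of the output.
    if params.stop_strs:
        tail_str = tokenizer.decode(output_ids[-(params.stop_str_max_len + 1):])
        for stop_str in params.stop_strs:
            if stop_str in tail_str:
                return ("stop_str", stop_str)
    return None  # not finished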
@@ -321,8 +327,7 @@ class ScheduleBatch:
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
-    frequency_penalties: torch.Tensor = None
-    presence_penalties: torch.Tensor = None
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     logit_bias: torch.Tensor = None

     @classmethod
@@ -386,15 +391,24 @@
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
         )
-        self.frequency_penalties = torch.tensor(
-            [r.sampling_params.frequency_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.presence_penalties = torch.tensor(
-            [r.sampling_params.presence_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
+
+        # Each penalizer does nothing if it decides, by inspecting the sampling_params of
+        # the requests, that it is not required (see {_is_required()} of each penalizer).
+        # So this should not add heavy computation overhead beyond those simple checks.
+        #
+        # We could avoid creating the instances entirely when they are not required, but
+        # that would add complexity to the {ScheduleBatch} class, since the {filter_batch()}
+        # and {merge()} cases would need to be handled as well.
+        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=self,
+            device=device,
+            Penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            },
+        )

         # Handle logit bias but only allocate when needed
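
The comment in this hunk leans on a gating pattern: each penalizer checks whether any request in the batch actually needs it, and otherwise degenerates to a no-op. A minimal sketch of that pattern for the frequency penalty, under a simplified standalone interface (the real penaltylib classes share state through the orchestrator and differ in detail):

import torch

class BatchedFrequencyPenalizerSketch:
    """Sketch of an _is_required()-gated penalizer; simplified from penaltylib."""

    def __init__(self, reqs, vocab_size, device):
        self.penalties = torch.tensor(
            [r.sampling_params.frequency_penalty for r in reqs],
            dtype=torch.float, device=device,
        )
        # Disabled entirely when every request uses the neutral value 0.0,
        # so the common case costs exactly this one check per batch.
        self.required = self._is_required()
        if self.required:
            self.output_counts = torch.zeros(
                (len(reqs), vocab_size), dtype=torch.float, device=device
            )

    def _is_required(self) -> bool:
        return bool(torch.any(self.penalties != 0.0))

    def cumulate_output_tokens(self, token_ids: torch.Tensor):
        # token_ids: one sampled token id per request, shape (batch,).
        if self.required:
            rows = torch.arange(token_ids.numel(), device=token_ids.device)
            self.output_counts[rows, token_ids] += 1.0

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.required:
            logits -= self.penalties.unsqueeze(1) * self.output_counts
        return logits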
@@ -617,6 +631,9 @@
             input_ids = [
                 r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
             ]
+        else:
+            self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
@@ -648,12 +665,12 @@
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)

+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
         for item in [
             "temperatures",
             "top_ps",
             "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
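
Because the penalizers now hold per-request tensors (penalty values, token counts), filter and merge must re-index that state exactly as the batch re-indexes temperatures and friends. A sketch of the two operations on a generic per-request state tensor, assuming row order follows request order (not the actual penaltylib code):

import torch

def filter_state(state: torch.Tensor, unfinished_indices: list[int]) -> torch.Tensor:
    """Keep the state rows of the requests that survive filtering."""
    keep = torch.tensor(unfinished_indices, dtype=torch.long, device=state.device)
    return state[keep]

def merge_state(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Concatenate per-request state rows when two batches are merged."""
    return torch.cat([a, b], dim=0)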
@@ -674,12 +691,12 @@
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)

+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
         for item in [
             "temperatures",
             "top_ps",
             "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
         ]:
             self_val = getattr(self, item, None)
             other_val = getattr(other, item, None)
@@ -721,7 +738,8 @@
                 ] = 1
                 logits[i].masked_fill_(~allowed_mask, float("-inf"))

-        # TODO(lmzheng): apply penalty
+        logits = self.penalizer_orchestrator.apply(logits)
+
         probs = torch.softmax(logits, dim=-1)
         if not global_server_args_dict["disable_flashinfer_sampling"]:
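
apply() replaces the old TODO and runs on the raw logits just before softmax. The four penalizers named in the commit correspond to well-known formulas; the sketch below shows the standard OpenAI-style frequency and presence penalties, the CTRL-style repetition penalty, and a min-new-tokens mask, as those penalties are conventionally defined (a reference sketch, not penaltylib's internals):

import torch

def apply_penalties_sketch(logits, counts, freq_pen, pres_pen, rep_pen,
                           eos_mask, num_new_tokens, min_new_tokens):
    """logits: (B, V) raw scores; counts: (B, V) output-token counts;
    freq_pen/pres_pen/rep_pen/num_new_tokens/min_new_tokens: (B,);
    eos_mask: (V,) bool marking EOS/stop tokens."""
    # Frequency: subtract in proportion to how often a token was emitted.
    logits = logits - freq_pen.unsqueeze(1) * counts
    # Presence: flat subtraction for any token emitted at least once.
    logits = logits - pres_pen.unsqueeze(1) * (counts > 0).float()
    # Repetition (CTRL-style): multiplicatively shrink repeated tokens' scores.
    seen = counts > 0
    rep = rep_pen.unsqueeze(1).expand_as(logits)
    logits = torch.where(seen & (logits > 0), logits / rep, logits)
    logits = torch.where(seen & (logits <= 0), logits * rep, logits)
    # Min new tokens: forbid EOS/stop tokens until enough tokens were produced.
    too_short = (num_new_tokens < min_new_tokens).unsqueeze(1)  # (B, 1)
    logits = logits.masked_fill(too_short & eos_mask.unsqueeze(0), float("-inf"))
    return logits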
@@ -754,6 +772,8 @@
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )

+        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
+
         return batch_next_token_ids
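
Taken together, the hunks wire the orchestrator into four lifecycle points: cumulate_input_tokens when tokens are injected at decode time (e.g. jump-forward), apply before softmax, cumulate_output_tokens after sampling, and filter/merge to keep state aligned with the batch. A schematic of one decode step; the surrounding loop is illustrative, only the orchestrator calls come from this diff:

import torch

def decode_step_sketch(batch, logits):
    """Illustrative decode step showing where the orchestrator hooks in."""
    orch = batch.penalizer_orchestrator
    logits = orch.apply(logits)                  # adjust raw logits before softmax
    probs = torch.softmax(logits, dim=-1)
    next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
    orch.cumulate_output_tokens(next_token_ids)  # update per-token counts
    # Elsewhere in the lifecycle (not once per step):
    #   orch.cumulate_input_tokens(input_ids)         on externally injected tokens
    #   orch.filter(unfinished_indices, new_indices)  after filter_batch()
    #   orch.merge(other.penalizer_orchestrator)      after merge()
    return next_token_ids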