feat: frequency, min_new_tokens, presence, and repetition penalties (#973)
This commit is contained in:
@@ -24,7 +24,7 @@ import warnings
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
@@ -47,6 +47,7 @@ class RequestFuncInput:
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
extra_request_body: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -84,6 +85,7 @@ async def async_request_trt_llm(
|
||||
"stream": True,
|
||||
"min_length": request_func_input.output_len,
|
||||
"end_id": 1048576,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
if args.disable_ignore_eos:
|
||||
del payload["min_length"]
|
||||
@@ -154,6 +156,7 @@ async def async_request_openai_completions(
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": not args.disable_stream,
|
||||
"ignore_eos": not args.disable_ignore_eos,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
|
||||
@@ -542,6 +545,7 @@ async def benchmark(
|
||||
request_rate: float,
|
||||
disable_tqdm: bool,
|
||||
enable_multi: bool,
|
||||
extra_request_body: Dict[str, Any],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@@ -556,6 +560,7 @@ async def benchmark(
|
||||
api_url=api_url,
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
if not test_output.success:
|
||||
@@ -578,6 +583,7 @@ async def benchmark(
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
@@ -746,6 +752,10 @@ def fire(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
extra_request_body = {}
|
||||
if args.extra_request_body:
|
||||
extra_request_body = json.loads(args.extra_request_body)
|
||||
|
||||
if args.port is None:
|
||||
args.port = {
|
||||
"sglang": 30000,
|
||||
@@ -838,6 +848,7 @@ def fire(args: argparse.Namespace):
|
||||
request_rate=rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
enable_multi=args.multi,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -851,6 +862,7 @@ def fire(args: argparse.Namespace):
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
enable_multi=args.multi,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -976,6 +988,13 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Disable ignoring EOS.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extra-request-body",
|
||||
metavar='{"key1": "value1", "key2": "value2"}',
|
||||
type=str,
|
||||
help="Append given JSON object to the request payload. You can use this to specify"
|
||||
"additional generate params like sampling params.",
|
||||
)
|
||||
|
||||
set_ulimit()
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import numpy as np
|
||||
import torch
|
||||
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
@@ -222,8 +223,9 @@ class Req:
|
||||
)
|
||||
return
|
||||
|
||||
last_token_id = self.output_ids[-1]
|
||||
if (
|
||||
self.output_ids[-1] == self.tokenizer.eos_token_id
|
||||
last_token_id == self.tokenizer.eos_token_id
|
||||
and not self.sampling_params.ignore_eos
|
||||
):
|
||||
self.finished_reason = FINISH_MATCHED_TOKEN(
|
||||
@@ -231,6 +233,10 @@ class Req:
|
||||
)
|
||||
return
|
||||
|
||||
if last_token_id in self.sampling_params.stop_token_ids:
|
||||
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
||||
return
|
||||
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
tail_str = self.tokenizer.decode(
|
||||
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
|
||||
@@ -321,8 +327,7 @@ class ScheduleBatch:
|
||||
temperatures: torch.Tensor = None
|
||||
top_ps: torch.Tensor = None
|
||||
top_ks: torch.Tensor = None
|
||||
frequency_penalties: torch.Tensor = None
|
||||
presence_penalties: torch.Tensor = None
|
||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||
logit_bias: torch.Tensor = None
|
||||
|
||||
@classmethod
|
||||
@@ -386,15 +391,24 @@ class ScheduleBatch:
|
||||
self.top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
||||
)
|
||||
self.frequency_penalties = torch.tensor(
|
||||
[r.sampling_params.frequency_penalty for r in reqs],
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
)
|
||||
self.presence_penalties = torch.tensor(
|
||||
[r.sampling_params.presence_penalty for r in reqs],
|
||||
dtype=torch.float,
|
||||
|
||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||
# should not add hefty computation overhead other than simple checks.
|
||||
#
|
||||
# While we choose not to even create the class instances if they are not required, this
|
||||
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
||||
# handle {filter_batch()} and {merge()} cases as well.
|
||||
self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=self,
|
||||
device=device,
|
||||
Penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
penaltylib.BatchedPresencePenalizer,
|
||||
penaltylib.BatchedRepetitionPenalizer,
|
||||
},
|
||||
)
|
||||
|
||||
# Handle logit bias but only allocate when needed
|
||||
@@ -617,6 +631,9 @@ class ScheduleBatch:
|
||||
input_ids = [
|
||||
r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
|
||||
]
|
||||
else:
|
||||
self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
|
||||
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
@@ -648,12 +665,12 @@ class ScheduleBatch:
|
||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
|
||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"frequency_penalties",
|
||||
"presence_penalties",
|
||||
"logit_bias",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
@@ -674,12 +691,12 @@ class ScheduleBatch:
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"frequency_penalties",
|
||||
"presence_penalties",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
other_val = getattr(other, item, None)
|
||||
@@ -721,7 +738,8 @@ class ScheduleBatch:
|
||||
] = 1
|
||||
logits[i].masked_fill_(~allowed_mask, float("-inf"))
|
||||
|
||||
# TODO(lmzheng): apply penalty
|
||||
logits = self.penalizer_orchestrator.apply(logits)
|
||||
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
|
||||
if not global_server_args_dict["disable_flashinfer_sampling"]:
|
||||
@@ -754,6 +772,8 @@ class ScheduleBatch:
|
||||
req.regex_fsm_state, batch_next_token_ids_cpu[i]
|
||||
)
|
||||
|
||||
self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
|
||||
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
|
||||
@@ -392,10 +392,13 @@ def v1_generate_request(all_requests):
|
||||
{
|
||||
"temperature": request.temperature,
|
||||
"max_new_tokens": request.max_tokens,
|
||||
"min_new_tokens": request.min_tokens,
|
||||
"stop": request.stop,
|
||||
"stop_token_ids": request.stop_token_ids,
|
||||
"top_p": request.top_p,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"n": request.n,
|
||||
"ignore_eos": request.ignore_eos,
|
||||
@@ -722,10 +725,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
|
||||
{
|
||||
"temperature": request.temperature,
|
||||
"max_new_tokens": request.max_tokens,
|
||||
"min_new_tokens": request.min_tokens,
|
||||
"stop": stop,
|
||||
"stop_token_ids": request.stop_token_ids,
|
||||
"top_p": request.top_p,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"n": request.n,
|
||||
}
|
||||
|
||||
@@ -162,6 +162,9 @@ class CompletionRequest(BaseModel):
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
regex: Optional[str] = None
|
||||
ignore_eos: Optional[bool] = False
|
||||
min_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CompletionResponseChoice(BaseModel):
|
||||
@@ -259,6 +262,9 @@ class ChatCompletionRequest(BaseModel):
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
regex: Optional[str] = None
|
||||
min_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
||||
13
python/sglang/srt/sampling/penaltylib/__init__.py
Normal file
13
python/sglang/srt/sampling/penaltylib/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .orchestrator import BatchedPenalizerOrchestrator
|
||||
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
|
||||
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
|
||||
from .penalizers.presence_penalty import BatchedPresencePenalizer
|
||||
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
|
||||
|
||||
__all__ = [
|
||||
"BatchedFrequencyPenalizer",
|
||||
"BatchedMinNewTokensPenalizer",
|
||||
"BatchedPresencePenalizer",
|
||||
"BatchedRepetitionPenalizer",
|
||||
"BatchedPenalizerOrchestrator",
|
||||
]
|
||||
353
python/sglang/srt/sampling/penaltylib/orchestrator.py
Normal file
353
python/sglang/srt/sampling/penaltylib/orchestrator.py
Normal file
@@ -0,0 +1,353 @@
|
||||
import abc
|
||||
import dataclasses
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _ReqLike:
|
||||
origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BatchLike:
|
||||
reqs: typing.List[_ReqLike]
|
||||
|
||||
def batch_size(self):
|
||||
return len(self.reqs)
|
||||
|
||||
|
||||
class BatchedPenalizerOrchestrator:
|
||||
batch: _BatchLike
|
||||
device: str
|
||||
vocab_size: int
|
||||
penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
batch: _BatchLike,
|
||||
device: str,
|
||||
Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.batch = batch
|
||||
self.device = device
|
||||
|
||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.prepare_if_required()
|
||||
|
||||
self.cumulate_input_tokens(
|
||||
input_ids=[req.origin_input_ids for req in self.reqs()]
|
||||
)
|
||||
|
||||
def reqs(self):
|
||||
return self.batch.reqs
|
||||
|
||||
def batch_size(self):
|
||||
return self.batch.batch_size()
|
||||
|
||||
def cumulate_input_tokens(
|
||||
self,
|
||||
input_ids: typing.Union[
|
||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
||||
],
|
||||
):
|
||||
"""
|
||||
Feed the input tokens to the penalizers.
|
||||
|
||||
Args:
|
||||
input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
|
||||
"""
|
||||
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.cumulate_input_tokens(input_ids=token_ids)
|
||||
|
||||
def cumulate_output_tokens(
|
||||
self,
|
||||
output_ids: typing.Union[
|
||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
||||
],
|
||||
):
|
||||
"""
|
||||
Feed the output tokens to the penalizers.
|
||||
|
||||
Args:
|
||||
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
|
||||
"""
|
||||
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.cumulate_output_tokens(output_ids=token_ids)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the penalizers to the logits.
|
||||
Note that it may apply the penalizers in-place.
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): The logits to apply the penalizers to.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The logits after applying the penalizers.
|
||||
"""
|
||||
for penalizer in self.penalizers.values():
|
||||
logits = penalizer.apply(logits)
|
||||
|
||||
return logits
|
||||
|
||||
def filter(
|
||||
self,
|
||||
indices_to_keep: typing.List[int],
|
||||
indices_tensor_to_keep: torch.Tensor = None,
|
||||
):
|
||||
"""
|
||||
Filter the penalizers based on the indices to keep in the batch.
|
||||
|
||||
Args:
|
||||
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
|
||||
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
||||
"""
|
||||
empty_indices = len(indices_to_keep) == 0
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
if not penalizer.is_required() or empty_indices:
|
||||
penalizer.teardown()
|
||||
else:
|
||||
# create tensor index only when it's needed
|
||||
if indices_tensor_to_keep is None:
|
||||
indices_tensor_to_keep = torch.tensor(
|
||||
indices_to_keep, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
penalizer.filter(
|
||||
indices_to_keep=indices_to_keep,
|
||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
||||
)
|
||||
|
||||
def merge(self, their: "BatchedPenalizerOrchestrator"):
|
||||
"""
|
||||
Merge the penalizers of another orchestrator into this one.
|
||||
|
||||
Args:
|
||||
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
|
||||
"""
|
||||
if self.vocab_size != their.vocab_size:
|
||||
raise ValueError(
|
||||
f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
|
||||
)
|
||||
|
||||
for Penalizer, their_penalizer in their.penalizers.items():
|
||||
if Penalizer not in self.penalizers:
|
||||
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
|
||||
|
||||
self.penalizers[Penalizer].merge(their_penalizer)
|
||||
|
||||
|
||||
class _TokenIDs:
|
||||
"""
|
||||
A class that wraps token IDs to provide additional utility functions to penalizers.
|
||||
|
||||
Attributes:
|
||||
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
|
||||
token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
|
||||
cached_counts (torch.Tensor): The cached occurrence count tensor.
|
||||
"""
|
||||
|
||||
orchestrator: BatchedPenalizerOrchestrator
|
||||
token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
|
||||
cached_counts: torch.Tensor = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
orchestrator: BatchedPenalizerOrchestrator,
|
||||
token_ids: typing.Union[
|
||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
||||
],
|
||||
):
|
||||
self.orchestrator = orchestrator
|
||||
|
||||
if not isinstance(token_ids[0], torch.Tensor):
|
||||
token_ids = [
|
||||
torch.tensor(
|
||||
data=ids, dtype=torch.int64, device=self.orchestrator.device
|
||||
)
|
||||
for ids in token_ids
|
||||
]
|
||||
|
||||
self.token_ids = token_ids
|
||||
|
||||
def occurrence_count(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The occurrence count tensor.
|
||||
"""
|
||||
if self.cached_counts is not None:
|
||||
return self.cached_counts
|
||||
|
||||
token_ids = self.token_ids
|
||||
|
||||
if isinstance(token_ids, torch.Tensor):
|
||||
token_ids = token_ids.unsqueeze(1)
|
||||
|
||||
# needs to be long to be used as index in scatter_add
|
||||
if token_ids.dtype != torch.int64:
|
||||
token_ids = token_ids.to(torch.int64)
|
||||
|
||||
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=token_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.orchestrator.vocab_size,
|
||||
)
|
||||
|
||||
self.cached_counts = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
||||
dtype=torch.int64,
|
||||
device=self.orchestrator.device,
|
||||
).scatter_add_(
|
||||
dim=1,
|
||||
index=padded_token_ids,
|
||||
src=torch.ones_like(padded_token_ids),
|
||||
)[
|
||||
:, : self.orchestrator.vocab_size
|
||||
]
|
||||
|
||||
return self.cached_counts
|
||||
|
||||
|
||||
class _BatchedPenalizer(abc.ABC):
|
||||
"""
|
||||
An abstract class for a batched penalizer.
|
||||
"""
|
||||
|
||||
orchestrator: BatchedPenalizerOrchestrator
|
||||
_is_prepared: bool = False
|
||||
|
||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
|
||||
def is_prepared(self) -> bool:
|
||||
return self._is_prepared
|
||||
|
||||
def is_required(self) -> bool:
|
||||
return self._is_required()
|
||||
|
||||
def prepare(self):
|
||||
if not self.is_prepared():
|
||||
self._prepare()
|
||||
self._is_prepared = True
|
||||
|
||||
def prepare_if_required(self):
|
||||
if self.is_required():
|
||||
self.prepare()
|
||||
|
||||
def teardown(self):
|
||||
if self.is_prepared():
|
||||
self._teardown()
|
||||
self._is_prepared = False
|
||||
|
||||
def cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
if not self.is_prepared():
|
||||
return
|
||||
|
||||
self._cumulate_input_tokens(input_ids=input_ids)
|
||||
|
||||
def cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
if not self.is_prepared():
|
||||
return
|
||||
|
||||
self._cumulate_output_tokens(output_ids=output_ids)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if not self.is_prepared():
|
||||
return logits
|
||||
|
||||
return self._apply(logits=logits)
|
||||
|
||||
def filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
if not self.is_prepared():
|
||||
return
|
||||
|
||||
self._filter(
|
||||
indices_to_keep=indices_to_keep,
|
||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
||||
)
|
||||
|
||||
def merge(self, their: "_BatchedPenalizer"):
|
||||
if not self.is_prepared() and not their.is_prepared():
|
||||
return
|
||||
|
||||
self.prepare()
|
||||
their.prepare()
|
||||
self._merge(their)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _is_required(self) -> bool:
|
||||
"""
|
||||
Check if the penalizer is required to be prepared.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _prepare(self):
|
||||
"""
|
||||
Prepare the penalizer.
|
||||
Usually, this is where the penalizer initializes its tensors.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _teardown(self):
|
||||
"""
|
||||
Tear down the penalizer.
|
||||
Usually, this is where the penalizer frees its tensors.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
"""
|
||||
Cumulate the input tokens.
|
||||
Orchestrator will call this function to feed the input tokens to the penalizer.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
"""
|
||||
Cumulate the output tokens.
|
||||
Orchestrator will call this function to feed the output tokens to the penalizer.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the penalizer to the logits.
|
||||
Penalizers can modify the logits in-place if needed.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
"""
|
||||
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _merge(self, their: "_BatchedPenalizer"):
|
||||
"""
|
||||
Merge the penalizer with another penalizer.
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,80 @@
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Frequency penalizer penalizes tokens based on their frequency in the output.
|
||||
"""
|
||||
|
||||
frequency_penalties: torch.Tensor = None
|
||||
cumulated_frequency_penalties: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.frequency_penalty != 0.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_frequency_penalties = (
|
||||
torch.tensor(
|
||||
data=[0.0 for _ in self.orchestrator.reqs()],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.repeat(1, self.orchestrator.vocab_size)
|
||||
)
|
||||
|
||||
self.frequency_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.frequency_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.expand_as(self.cumulated_frequency_penalties)
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.frequency_penalties
|
||||
del self.cumulated_frequency_penalties
|
||||
|
||||
self.frequency_penalties = None
|
||||
self.cumulated_frequency_penalties = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
pass
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
self.cumulated_frequency_penalties += (
|
||||
self.frequency_penalties * output_ids.occurrence_count()
|
||||
)
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
logits -= self.cumulated_frequency_penalties
|
||||
return logits
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
||||
indices_tensor_to_keep
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedFrequencyPenalizer"):
|
||||
self.frequency_penalties = torch.cat(
|
||||
[self.frequency_penalties, their.frequency_penalties], dim=0
|
||||
)
|
||||
self.cumulated_frequency_penalties = torch.cat(
|
||||
[self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
|
||||
dim=0,
|
||||
)
|
||||
@@ -0,0 +1,105 @@
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Min new tokens penalizer penalizes tokens based on the length of the output.
|
||||
"""
|
||||
|
||||
min_new_tokens: torch.Tensor = None
|
||||
stop_token_penalties: torch.Tensor = None
|
||||
len_output_tokens: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.min_new_tokens = torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.min_new_tokens for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=self.orchestrator.device,
|
||||
).unsqueeze_(1)
|
||||
|
||||
padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=[
|
||||
torch.tensor(
|
||||
data=list(
|
||||
req.sampling_params.stop_token_ids
|
||||
| {req.tokenizer.eos_token_id}
|
||||
),
|
||||
dtype=torch.int64,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
batch_first=True,
|
||||
padding_value=self.orchestrator.vocab_size,
|
||||
)
|
||||
self.stop_token_penalties = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
).scatter_add_(
|
||||
dim=1,
|
||||
index=padded_stop_token_ids,
|
||||
src=torch.full_like(
|
||||
input=padded_stop_token_ids,
|
||||
dtype=torch.float32,
|
||||
fill_value=float("-inf"),
|
||||
device=self.orchestrator.device,
|
||||
),
|
||||
)[
|
||||
:, : self.orchestrator.vocab_size
|
||||
]
|
||||
|
||||
self.len_output_tokens = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), 1),
|
||||
dtype=torch.int32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.min_new_tokens
|
||||
del self.stop_token_penalties
|
||||
del self.len_output_tokens
|
||||
|
||||
self.min_new_tokens = None
|
||||
self.stop_token_penalties = None
|
||||
self.len_output_tokens = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
pass
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
self.len_output_tokens += 1
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
|
||||
logits[mask] += self.stop_token_penalties[mask]
|
||||
return logits
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
|
||||
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
|
||||
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
|
||||
|
||||
def _merge(self, their: "BatchedMinNewTokensPenalizer"):
|
||||
self.min_new_tokens = torch.cat(
|
||||
[self.min_new_tokens, their.min_new_tokens], dim=0
|
||||
)
|
||||
self.stop_token_penalties = torch.cat(
|
||||
[self.stop_token_penalties, their.stop_token_penalties], dim=0
|
||||
)
|
||||
self.len_output_tokens = torch.cat(
|
||||
[self.len_output_tokens, their.len_output_tokens], dim=0
|
||||
)
|
||||
@@ -0,0 +1,79 @@
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedPresencePenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Presence penalizer penalizes tokens based on their presence in the output.
|
||||
"""
|
||||
|
||||
presence_penalties: torch.Tensor = None
|
||||
cumulated_presence_penalties: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.presence_penalty != 0.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_presence_penalties = (
|
||||
torch.tensor(
|
||||
data=[0.0 for _ in self.orchestrator.reqs()],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.repeat(1, self.orchestrator.vocab_size)
|
||||
)
|
||||
|
||||
self.presence_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.presence_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.expand_as(self.cumulated_presence_penalties)
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.presence_penalties
|
||||
del self.cumulated_presence_penalties
|
||||
|
||||
self.presence_penalties = None
|
||||
self.cumulated_presence_penalties = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
pass
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
mask = output_ids.occurrence_count() > 0
|
||||
self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
logits -= self.cumulated_presence_penalties
|
||||
return logits
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
||||
indices_tensor_to_keep
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedPresencePenalizer"):
|
||||
self.presence_penalties = torch.cat(
|
||||
[self.presence_penalties, their.presence_penalties], dim=0
|
||||
)
|
||||
self.cumulated_presence_penalties = torch.cat(
|
||||
[self.cumulated_presence_penalties, their.cumulated_presence_penalties],
|
||||
dim=0,
|
||||
)
|
||||
@@ -0,0 +1,83 @@
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Repetition penalizer penalizes tokens based on their repetition in the input and output.
|
||||
"""
|
||||
|
||||
repetition_penalties: torch.Tensor = None
|
||||
cumulated_repetition_penalties: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.repetition_penalty != 1.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_repetition_penalties = (
|
||||
torch.tensor(
|
||||
data=[1.0 for _ in self.orchestrator.reqs()],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.repeat(1, self.orchestrator.vocab_size)
|
||||
)
|
||||
|
||||
self.repetition_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.repetition_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.expand_as(self.cumulated_repetition_penalties)
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.repetition_penalties
|
||||
del self.cumulated_repetition_penalties
|
||||
|
||||
self.repetition_penalties = None
|
||||
self.cumulated_repetition_penalties = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
mask = input_ids.occurrence_count() > 0
|
||||
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
mask = output_ids.occurrence_count() > 0
|
||||
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return torch.where(
|
||||
logits > 0,
|
||||
logits / self.cumulated_repetition_penalties,
|
||||
logits * self.cumulated_repetition_penalties,
|
||||
)
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
|
||||
indices_tensor_to_keep
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedRepetitionPenalizer"):
|
||||
self.repetition_penalties = torch.cat(
|
||||
[self.repetition_penalties, their.repetition_penalties], dim=0
|
||||
)
|
||||
self.cumulated_repetition_penalties = torch.cat(
|
||||
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
|
||||
dim=0,
|
||||
)
|
||||
@@ -24,12 +24,15 @@ class SamplingParams:
|
||||
def __init__(
|
||||
self,
|
||||
max_new_tokens: int = 128,
|
||||
min_new_tokens: int = 0,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = [],
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
repetition_penalty: float = 1.0,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
@@ -42,8 +45,11 @@ class SamplingParams:
|
||||
self.top_k = top_k
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.repetition_penalty = repetition_penalty
|
||||
self.stop_strs = stop
|
||||
self.stop_token_ids = {*stop_token_ids}
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.min_new_tokens = min_new_tokens
|
||||
self.ignore_eos = ignore_eos
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
@@ -80,11 +86,26 @@ class SamplingParams:
|
||||
raise ValueError(
|
||||
"presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
|
||||
)
|
||||
if not 0.0 <= self.repetition_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
"repetition_penalty must be in (0, 2], got "
|
||||
f"{self.repetition_penalty}."
|
||||
)
|
||||
if not 0 <= self.min_new_tokens:
|
||||
raise ValueError(
|
||||
f"min_new_tokens must be in (0, max_new_tokens], got "
|
||||
f"{self.min_new_tokens}."
|
||||
)
|
||||
if self.max_new_tokens is not None:
|
||||
if self.max_new_tokens < 0:
|
||||
raise ValueError(
|
||||
f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
|
||||
)
|
||||
if not self.min_new_tokens <= self.max_new_tokens:
|
||||
raise ValueError(
|
||||
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
|
||||
f"{self.min_new_tokens}."
|
||||
)
|
||||
|
||||
def normalize(self, tokenizer):
|
||||
# Process stop strings
|
||||
|
||||
337
python/sglang/test/srt/sampling/penaltylib/utils.py
Normal file
337
python/sglang/test/srt/sampling/penaltylib/utils.py
Normal file
@@ -0,0 +1,337 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import typing
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||
BatchedPenalizerOrchestrator,
|
||||
_BatchedPenalizer,
|
||||
_BatchLike,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MockSamplingParams:
|
||||
frequency_penalty: float = 0.0
|
||||
min_new_tokens: int = 0
|
||||
stop_token_ids: typing.List[int] = None
|
||||
presence_penalty: float = 0.0
|
||||
repetition_penalty: float = 1.0
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MockTokenizer:
|
||||
eos_token_id: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MockReq:
|
||||
origin_input_ids: typing.List[int]
|
||||
sampling_params: MockSamplingParams
|
||||
tokenizer: MockTokenizer
|
||||
|
||||
|
||||
class StepType(enum.Enum):
|
||||
INPUT = "input"
|
||||
OUTPUT = "output"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Step:
|
||||
type: StepType
|
||||
token_ids: typing.List[int]
|
||||
expected_tensors: typing.Dict[str, torch.Tensor]
|
||||
# assume initial logits are all 1
|
||||
expected_logits: torch.Tensor
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Subject:
|
||||
sampling_params: MockSamplingParams
|
||||
# first step must be input, which will be converted to Req
|
||||
steps: typing.List[Step]
|
||||
eos_token_id: int = -1
|
||||
|
||||
def __post_init__(self):
|
||||
if self.steps[0].type != StepType.INPUT:
|
||||
raise ValueError("First step must be input")
|
||||
|
||||
# each steps should have the same expected_tensors.keys()
|
||||
for i in range(1, len(self.steps)):
|
||||
if self.tensor_keys(i) != self.tensor_keys():
|
||||
raise ValueError(
|
||||
f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
|
||||
)
|
||||
|
||||
def tensor_keys(self, i: int = 0) -> typing.Set[str]:
|
||||
return set(self.steps[i].expected_tensors.keys())
|
||||
|
||||
def to_req(self) -> MockReq:
|
||||
return MockReq(
|
||||
origin_input_ids=self.steps[0].token_ids,
|
||||
sampling_params=self.sampling_params,
|
||||
tokenizer=MockTokenizer(eos_token_id=self.eos_token_id),
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Case:
|
||||
enabled: bool
|
||||
test_subjects: typing.List[Subject]
|
||||
|
||||
def __post_init__(self):
|
||||
# each test_subjects.steps should have the same expected_tensors.keys()
|
||||
for i in range(1, len(self.test_subjects)):
|
||||
if self.tensor_keys(i) != self.tensor_keys():
|
||||
raise ValueError(
|
||||
f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
|
||||
)
|
||||
|
||||
def tensor_keys(self, i: int = 0) -> typing.List[str]:
|
||||
return set(self.test_subjects[i].tensor_keys())
|
||||
|
||||
|
||||
class BaseBatchedPenalizerTest(unittest.TestCase):
|
||||
Penalizer: typing.Type[_BatchedPenalizer]
|
||||
device = "cuda"
|
||||
vocab_size = 5
|
||||
|
||||
enabled: Subject = None
|
||||
disabled: Subject = None
|
||||
|
||||
def setUp(self):
|
||||
if self.__class__ == BaseBatchedPenalizerTest:
|
||||
self.skipTest("Base class for penalizer tests")
|
||||
|
||||
self.create_test_subjects()
|
||||
self.create_test_cases()
|
||||
|
||||
def tensor(self, data, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Shortcut to create a tensor with device=self.device.
|
||||
"""
|
||||
return torch.tensor(data, **kwargs, device=self.device)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_test_cases(self):
|
||||
self.test_cases = [
|
||||
Case(enabled=True, test_subjects=[self.enabled]),
|
||||
Case(enabled=False, test_subjects=[self.disabled]),
|
||||
Case(enabled=True, test_subjects=[self.enabled, self.disabled]),
|
||||
]
|
||||
|
||||
def _create_penalizer(
|
||||
self, case: Case
|
||||
) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
|
||||
orchestrator = BatchedPenalizerOrchestrator(
|
||||
vocab_size=self.vocab_size,
|
||||
batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
|
||||
device=self.device,
|
||||
Penalizers={self.Penalizer},
|
||||
)
|
||||
|
||||
return orchestrator, orchestrator.penalizers[self.Penalizer]
|
||||
|
||||
def test_is_required(self):
|
||||
for case in self.test_cases:
|
||||
with self.subTest(case=case):
|
||||
_, penalizer = self._create_penalizer(case)
|
||||
self.assertEqual(case.enabled, penalizer.is_required())
|
||||
|
||||
def test_prepare(self):
|
||||
for case in self.test_cases:
|
||||
with self.subTest(case=case):
|
||||
orchestrator, penalizer = self._create_penalizer(case)
|
||||
self.assertEqual(case.enabled, penalizer.is_prepared())
|
||||
|
||||
if case.enabled:
|
||||
for key, tensor in {
|
||||
key: torch.cat(
|
||||
tensors=[
|
||||
subject.steps[0].expected_tensors[key]
|
||||
for subject in case.test_subjects
|
||||
],
|
||||
)
|
||||
for key in case.tensor_keys()
|
||||
}.items():
|
||||
torch.testing.assert_close(
|
||||
actual=getattr(penalizer, key),
|
||||
expected=tensor,
|
||||
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
|
||||
)
|
||||
|
||||
actual = orchestrator.apply(
|
||||
torch.ones(
|
||||
size=(len(case.test_subjects), self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
expected = torch.cat(
|
||||
tensors=[
|
||||
subject.steps[0].expected_logits
|
||||
for subject in case.test_subjects
|
||||
],
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
actual=actual,
|
||||
expected=expected,
|
||||
msg=f"logits\nactual={actual}\nexpected={expected}",
|
||||
)
|
||||
|
||||
def test_teardown(self):
|
||||
for case in self.test_cases:
|
||||
with self.subTest(case=case):
|
||||
_, penalizer = self._create_penalizer(case)
|
||||
penalizer.teardown()
|
||||
|
||||
for key in case.test_subjects[0].steps[0].expected_tensors.keys():
|
||||
self.assertIsNone(getattr(penalizer, key, None))
|
||||
|
||||
def test_filter(self):
|
||||
for case in self.test_cases:
|
||||
with self.subTest(case=case):
|
||||
orchestrator, penalizer = self._create_penalizer(case)
|
||||
|
||||
indices_to_keep = [0]
|
||||
orchestrator.filter(indices_to_keep=indices_to_keep)
|
||||
|
||||
filtered_subjects = [case.test_subjects[i] for i in indices_to_keep]
|
||||
|
||||
if penalizer.is_required():
|
||||
self.assertTrue(penalizer.is_prepared())
|
||||
for key, tensor in {
|
||||
key: torch.cat(
|
||||
tensors=[
|
||||
subject.steps[0].expected_tensors[key]
|
||||
for subject in filtered_subjects
|
||||
],
|
||||
)
|
||||
for key in case.tensor_keys()
|
||||
}.items():
|
||||
torch.testing.assert_close(
|
||||
actual=getattr(penalizer, key),
|
||||
expected=tensor,
|
||||
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
|
||||
)
|
||||
|
||||
actual_logits = orchestrator.apply(
|
||||
torch.ones(
|
||||
size=(len(filtered_subjects), self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
filtered_expected_logits = torch.cat(
|
||||
tensors=[
|
||||
subject.steps[0].expected_logits
|
||||
for subject in filtered_subjects
|
||||
],
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
actual=actual_logits,
|
||||
expected=filtered_expected_logits,
|
||||
msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
|
||||
)
|
||||
|
||||
def test_merge_enabled_with_disabled(self):
|
||||
enabled_test_case = self.test_cases[0]
|
||||
disabled_test_case = self.test_cases[1]
|
||||
|
||||
orchestrator, penalizer = self._create_penalizer(enabled_test_case)
|
||||
theirs, _ = self._create_penalizer(disabled_test_case)
|
||||
|
||||
orchestrator.merge(theirs)
|
||||
|
||||
for key, tensor in {
|
||||
key: torch.cat(
|
||||
tensors=[
|
||||
enabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
|
||||
disabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
|
||||
],
|
||||
)
|
||||
for key in enabled_test_case.tensor_keys()
|
||||
}.items():
|
||||
torch.testing.assert_close(
|
||||
actual=getattr(penalizer, key),
|
||||
expected=tensor,
|
||||
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
|
||||
)
|
||||
|
||||
def test_cumulate_apply_repeat(self):
|
||||
for case in self.test_cases:
|
||||
with self.subTest(case=case):
|
||||
orchestrator, penalizer = self._create_penalizer(case)
|
||||
|
||||
max_step = max(len(subject.steps) for subject in case.test_subjects)
|
||||
for i in range(1, max_step):
|
||||
orchestrator.filter(
|
||||
indices_to_keep=[
|
||||
j
|
||||
for j, subject in enumerate(case.test_subjects)
|
||||
if i < len(subject.steps)
|
||||
]
|
||||
)
|
||||
|
||||
filtered_subjects = [
|
||||
subject
|
||||
for subject in case.test_subjects
|
||||
if i < len(subject.steps)
|
||||
]
|
||||
|
||||
inputs: typing.List[typing.List[int]] = []
|
||||
outputs: typing.List[typing.List[int]] = []
|
||||
for subject in filtered_subjects:
|
||||
step = subject.steps[i]
|
||||
if step.type == StepType.INPUT:
|
||||
inputs.append(step.token_ids)
|
||||
outputs.append([])
|
||||
else:
|
||||
inputs.append([])
|
||||
outputs.append(step.token_ids)
|
||||
|
||||
if any(inputs):
|
||||
orchestrator.cumulate_input_tokens(inputs)
|
||||
|
||||
if any(outputs):
|
||||
orchestrator.cumulate_output_tokens(outputs)
|
||||
|
||||
if penalizer.is_required():
|
||||
self.assertTrue(penalizer.is_prepared())
|
||||
for key, tensor in {
|
||||
key: torch.cat(
|
||||
tensors=[
|
||||
subject.steps[i].expected_tensors[key]
|
||||
for subject in filtered_subjects
|
||||
],
|
||||
)
|
||||
for key in case.tensor_keys()
|
||||
}.items():
|
||||
torch.testing.assert_close(
|
||||
actual=getattr(penalizer, key),
|
||||
expected=tensor,
|
||||
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
|
||||
)
|
||||
|
||||
actual_logits = orchestrator.apply(
|
||||
torch.ones(
|
||||
size=(len(filtered_subjects), self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
filtered_expected_logits = torch.cat(
|
||||
tensors=[
|
||||
subject.steps[i].expected_logits
|
||||
for subject in filtered_subjects
|
||||
],
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
actual=actual_logits,
|
||||
expected=filtered_expected_logits,
|
||||
msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
|
||||
)
|
||||
@@ -482,7 +482,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
|
||||
p.terminate()
|
||||
time.sleep(5)
|
||||
print(
|
||||
"\nTimeout after {timeout_per_file} seconds when running {filename}\n"
|
||||
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user