From ab7875941b34200529eddd1fb950efa981dc3866 Mon Sep 17 00:00:00 2001
From: Juwan Yoo
Date: Thu, 8 Aug 2024 04:21:08 -0700
Subject: [PATCH] feat: frequency, min_new_tokens, presence, and repetition penalties (#973)

---
 docs/en/sampling_params.md | 273 ++++++++++++++
 python/sglang/bench_serving.py | 21 +-
 python/sglang/srt/managers/schedule_batch.py | 52 ++-
 python/sglang/srt/openai_api/adapter.py | 6 +
 python/sglang/srt/openai_api/protocol.py | 6 +
 .../srt/sampling/penaltylib/__init__.py | 13 +
 .../srt/sampling/penaltylib/orchestrator.py | 353 ++++++++++++++++++
 .../penalizers/frequency_penalty.py | 80 ++++
 .../penaltylib/penalizers/min_new_tokens.py | 105 ++++++
 .../penaltylib/penalizers/presence_penalty.py | 79 ++++
 .../penalizers/repetition_penalty.py | 83 ++++
 python/sglang/srt/sampling_params.py | 21 ++
 .../test/srt/sampling/penaltylib/utils.py | 337 +++++++++++++++++
 python/sglang/test/test_utils.py | 2 +-
 test/srt/run_suite.py | 11 +
 .../penalizers/test_frequency_penalty.py | 80 ++++
 .../penalizers/test_min_new_tokens.py | 152 ++++++++
 .../penalizers/test_presence_penalty.py | 80 ++++
 .../penalizers/test_repetition_penalty.py | 87 +++++
 .../test_srt_endpoint_with_penalizers.py | 75 ++++
 20 files changed, 1898 insertions(+), 18 deletions(-)
 create mode 100644 python/sglang/srt/sampling/penaltylib/__init__.py
 create mode 100644 python/sglang/srt/sampling/penaltylib/orchestrator.py
 create mode 100644 python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py
 create mode 100644 python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
 create mode 100644 python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py
 create mode 100644 python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
 create mode 100644 python/sglang/test/srt/sampling/penaltylib/utils.py
 create mode 100644 test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py
 create mode 100644 test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py
 create mode 100644 test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py
 create mode 100644 test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py
 create mode 100644 test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py

diff --git a/docs/en/sampling_params.md b/docs/en/sampling_params.md
index 782bb1fb6..5f1cdece6 100644
--- a/docs/en/sampling_params.md
+++ b/docs/en/sampling_params.md
@@ -36,6 +36,9 @@ The `sampling_params` follows this format
 max_new_tokens: int = 128,
 # Stop when hitting any of the strings in this list.
 stop: Optional[Union[str, List[str]]] = None,
+# Stop when hitting any of the token_ids in this list. Can be useful when combined with
+# `min_new_tokens`.
+stop_token_ids: Optional[List[int]] = [],
 # Sampling temperature
 temperature: float = 1.0,
 # Top-p sampling
@@ -52,6 +55,27 @@ spaces_between_special_tokens: bool = True,
 regex: Optional[str] = None,
 # Do parallel sampling and return `n` outputs.
 n: int = 1,
+
+## Penalties. See the [Performance Implications on Penalties] section below for more information.
+
+# Float that penalizes new tokens based on their frequency in the generated text so far.
+# Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to
+# repeat tokens. Must be -2 <= value <= 2. Setting to 0 (default) will disable this penalty.
+frequency_penalty: float = 0.0,
+# Float that penalizes new tokens based on whether they appear in the generated text so far.
+# Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat
+# tokens. Must be -2 <= value <= 2. Setting to 0 (default) will disable this penalty.
+presence_penalty: float = 0.0,
+# Float that penalizes new tokens based on whether they appear in the prompt and the generated text
+# so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to
+# repeat tokens. Must be 0 <= value <= 2. Setting to 1 (default) will disable this penalty.
+repetition_penalty: float = 1.0,
+# Guides inference to generate at least this number of tokens by penalizing the logits of the
+# tokenizer's EOS token and all `stop_token_ids` to -inf, until the output reaches the given length.
+# Note that any of the `stop` strings can still be generated before reaching `min_new_tokens`, as it
+# is difficult to infer the matching token IDs from the given `stop` strings.
+# Must be 0 <= value <= max_new_tokens. Setting to 0 (default) will disable this penalty.
+min_new_tokens: int = 0,
 ```
 
 ## Examples
@@ -142,3 +166,252 @@ print(response.json())
 The `image_data` can be a file name, a URL, or a base64 encoded string. See also `python/sglang/srt/utils.py:load_image`.
 
 Streaming is supported in a similar manner as [above](#streaming).
+
+## Performance Implications on Penalties
+
+While you can apply penalties by supplying the relevant `sampling_params`, doing so comes with some drawbacks.
+
+These drawbacks apply to every request in the same batch, because the penalizers operate on the batch as a whole.
+
+### Latency
+
+Although the penalty algorithms are computed through CUDA, they are still extra computation on top of the basic sampling logic. For the exact overhead we recommend running your own benchmarks, but the samples below give a first impression.
+
+### Memory
+
+Since the penalty algorithms are computed through CUDA, the relevant parameters are stored on the GPU, usually on the order of `vocab_size` multiplied by the number of running requests.
+
+Run a benchmark with your desired parameters on your own hardware to make sure it does not run out of memory before deploying.
+
+Tuning `--mem-fraction-static` and/or `--max-running-requests` will help. See [here](hyperparameter_tuning.md#minor-tune---max-prefill-tokens---mem-fraction-static---max-running-requests) for more information.
+
+### Benchmarks
+
+All the benchmarks below were run on an NVIDIA H100 SXM5.
+
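+The benchmarks below inject the penalty parameters into every request through `--extra-request-body`. Outside of benchmarking, you can set the same fields in `sampling_params` directly, mirroring the earlier examples. A minimal sketch against a local server (the host, port, prompt, and penalty values here are only placeholders):
+
+```python
+import requests
+
+response = requests.post(
+    "http://localhost:30000/generate",
+    json={
+        "text": "The capital of France is",
+        "sampling_params": {
+            "max_new_tokens": 64,
+            # Force at least 8 tokens before EOS/stop tokens are allowed.
+            "min_new_tokens": 8,
+            "frequency_penalty": 0.5,
+            "presence_penalty": 0.3,
+            "repetition_penalty": 1.05,
+        },
+    },
+)
+print(response.json())
+```
+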
+ +#### Baseline + +Measured at [dc9d06d886151707f97d0b78095df9de262fd3c9](https://github.com/sgl-project/sglang/commit/dc9d06d886151707f97d0b78095df9de262fd3c9). + +``` +$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 + +============ Serving Benchmark Result ============ +Backend: sglang +Traffic request rate: inf +Successful requests: 3000 +Benchmark duration (s): 66.11 +Total input tokens: 378633 +Total generated tokens: 775651 +Total generated tokens (retokenized): 775118 +Request throughput (req/s): 45.38 +Input token throughput (tok/s): 5727.04 +Output token throughput (tok/s): 11732.16 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 40881.94 +Median E2E Latency (ms): 43967.10 +---------------Time to First Token---------------- +Mean TTFT (ms): 19884.75 +Median TTFT (ms): 14226.56 +P99 TTFT (ms): 47738.97 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 91.96 +Median TPOT (ms): 90.11 +P99 TPOT (ms): 308.54 +---------------Inter-token Latency---------------- +Mean ITL (ms): 174.54 +Median ITL (ms): 58.56 +P99 ITL (ms): 440.18 +================================================== +``` + +#### All Together + +``` +$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 --extra-request-body '{ + "frequency_penalty": 1.1, + "presence_penalty": 1.1, + "repetition_penalty": 0.1, + "min_new_tokens": 5 +}' + +============ Serving Benchmark Result ============ +Backend: sglang +Traffic request rate: inf +Successful requests: 3000 +Benchmark duration (s): 78.35 +Total input tokens: 378633 +Total generated tokens: 775651 +Total generated tokens (retokenized): 774756 +Request throughput (req/s): 38.29 +Input token throughput (tok/s): 4832.86 +Output token throughput (tok/s): 9900.39 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 49017.68 +Median E2E Latency (ms): 52825.70 +---------------Time to First Token---------------- +Mean TTFT (ms): 23892.60 +Median TTFT (ms): 18895.47 +P99 TTFT (ms): 57426.01 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 114.54 +Median TPOT (ms): 107.27 +P99 TPOT (ms): 293.31 +---------------Inter-token Latency---------------- +Mean ITL (ms): 205.68 +Median ITL (ms): 73.97 +P99 ITL (ms): 453.86 +================================================== +``` + +#### Frequency Penalty + +``` +$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 --extra-request-body '{ + "frequency_penalty": 1.1 +}' + +============ Serving Benchmark Result ============ +Backend: sglang +Traffic request rate: inf +Successful requests: 3000 +Benchmark duration (s): 72.72 +Total input tokens: 378633 +Total generated tokens: 775651 +Total generated tokens (retokenized): 774955 +Request throughput (req/s): 41.26 +Input token throughput (tok/s): 5206.84 +Output token throughput (tok/s): 10666.51 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 45445.56 +Median E2E Latency (ms): 48960.39 +---------------Time to First Token---------------- +Mean TTFT (ms): 22363.16 +Median TTFT (ms): 17125.02 +P99 TTFT (ms): 52920.95 +-----Time per Output Token (excl. 
1st token)------
+Mean TPOT (ms): 104.71
+Median TPOT (ms): 98.30
+P99 TPOT (ms): 268.06
+---------------Inter-token Latency----------------
+Mean ITL (ms): 191.60
+Median ITL (ms): 67.83
+P99 ITL (ms): 455.46
+==================================================
+```
+
+#### Presence Penalty
+
+```
+$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 --extra-request-body '{
+  "presence_penalty": 1.1
+}'
+
+============ Serving Benchmark Result ============
+Backend: sglang
+Traffic request rate: inf
+Successful requests: 3000
+Benchmark duration (s): 72.04
+Total input tokens: 378633
+Total generated tokens: 775651
+Total generated tokens (retokenized): 775210
+Request throughput (req/s): 41.64
+Input token throughput (tok/s): 5255.98
+Output token throughput (tok/s): 10767.18
+----------------End-to-End Latency----------------
+Mean E2E Latency (ms): 44926.61
+Median E2E Latency (ms): 48302.88
+---------------Time to First Token----------------
+Mean TTFT (ms): 22095.39
+Median TTFT (ms): 16740.93
+P99 TTFT (ms): 52554.03
+-----Time per Output Token (excl. 1st token)------
+Mean TPOT (ms): 103.54
+Median TPOT (ms): 97.37
+P99 TPOT (ms): 271.86
+---------------Inter-token Latency----------------
+Mean ITL (ms): 189.86
+Median ITL (ms): 68.45
+P99 ITL (ms): 447.11
+==================================================
+```
+
+#### Repetition Penalty
+
+```
+$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 --extra-request-body '{
+  "repetition_penalty": 0.1
+}'
+
+============ Serving Benchmark Result ============
+Backend: sglang
+Traffic request rate: inf
+Successful requests: 3000
+Benchmark duration (s): 74.54
+Total input tokens: 378633
+Total generated tokens: 775651
+Total generated tokens (retokenized): 766008
+Request throughput (req/s): 40.24
+Input token throughput (tok/s): 5079.36
+Output token throughput (tok/s): 10405.35
+----------------End-to-End Latency----------------
+Mean E2E Latency (ms): 46530.38
+Median E2E Latency (ms): 50302.65
+---------------Time to First Token----------------
+Mean TTFT (ms): 22603.47
+Median TTFT (ms): 17167.08
+P99 TTFT (ms): 54497.85
+-----Time per Output Token (excl. 1st token)------
+Mean TPOT (ms): 117.59
+Median TPOT (ms): 101.79
+P99 TPOT (ms): 320.04
+---------------Inter-token Latency----------------
+Mean ITL (ms): 195.26
+Median ITL (ms): 69.51
+P99 ITL (ms): 433.86
+==================================================
+```
+
+#### Min New Tokens
+
+The min new tokens penalizer keeps computing until the generation reaches the given `min_new_tokens`.
+
+Unlike the other penalizers, setting this to a higher value will have larger latency implications.
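+
+Conceptually, the penalizer keeps adding a `-inf` penalty to the logits of the EOS token and every entry in `stop_token_ids` until `min_new_tokens` outputs have been generated. A simplified, standalone sketch of that masking step (toy batch and vocabulary sizes, not the actual implementation):
+
+```python
+import torch
+
+vocab_size, eos_token_id = 5, 4
+min_new_tokens = torch.tensor([[3], [0]])     # per-request thresholds
+len_output_tokens = torch.tensor([[1], [1]])  # tokens generated so far
+
+stop_token_penalties = torch.zeros(2, vocab_size)
+stop_token_penalties[:, eos_token_id] = float("-inf")
+
+logits = torch.ones(2, vocab_size)
+mask = (len_output_tokens < min_new_tokens).expand_as(logits)
+logits[mask] += stop_token_penalties[mask]  # request 0 cannot emit EOS yet
+```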
+ +``` +$ python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 --extra-request-body '{ + "min_new_tokens": 5 +}' + +============ Serving Benchmark Result ============ +Backend: sglang +Traffic request rate: inf +Successful requests: 3000 +Benchmark duration (s): 66.94 +Total input tokens: 378633 +Total generated tokens: 775651 +Total generated tokens (retokenized): 775220 +Request throughput (req/s): 44.81 +Input token throughput (tok/s): 5656.13 +Output token throughput (tok/s): 11586.90 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 41888.55 +Median E2E Latency (ms): 45354.16 +---------------Time to First Token---------------- +Mean TTFT (ms): 20866.91 +Median TTFT (ms): 16219.79 +P99 TTFT (ms): 49263.91 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 97.05 +Median TPOT (ms): 89.76 +P99 TPOT (ms): 233.50 +---------------Inter-token Latency---------------- +Mean ITL (ms): 179.17 +Median ITL (ms): 55.08 +P99 ITL (ms): 409.12 +================================================== +``` + +
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 253aab355..cc2406846 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -24,7 +24,7 @@ import warnings from argparse import ArgumentParser from dataclasses import dataclass, field from datetime import datetime -from typing import AsyncGenerator, List, Optional, Tuple, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union import aiohttp import numpy as np @@ -47,6 +47,7 @@ class RequestFuncInput: prompt_len: int output_len: int model: str + extra_request_body: Dict[str, Any] @dataclass @@ -84,6 +85,7 @@ async def async_request_trt_llm( "stream": True, "min_length": request_func_input.output_len, "end_id": 1048576, + **request_func_input.extra_request_body, } if args.disable_ignore_eos: del payload["min_length"] @@ -154,6 +156,7 @@ async def async_request_openai_completions( "max_tokens": request_func_input.output_len, "stream": not args.disable_stream, "ignore_eos": not args.disable_ignore_eos, + **request_func_input.extra_request_body, } headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} @@ -542,6 +545,7 @@ async def benchmark( request_rate: float, disable_tqdm: bool, enable_multi: bool, + extra_request_body: Dict[str, Any], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -556,6 +560,7 @@ async def benchmark( api_url=api_url, prompt_len=test_prompt_len, output_len=test_output_len, + extra_request_body=extra_request_body, ) test_output = await request_func(request_func_input=test_input) if not test_output.success: @@ -578,6 +583,7 @@ async def benchmark( api_url=api_url, prompt_len=prompt_len, output_len=output_len, + extra_request_body=extra_request_body, ) tasks.append( asyncio.create_task( @@ -746,6 +752,10 @@ def fire(args: argparse.Namespace): random.seed(args.seed) np.random.seed(args.seed) + extra_request_body = {} + if args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + if args.port is None: args.port = { "sglang": 30000, @@ -838,6 +848,7 @@ def fire(args: argparse.Namespace): request_rate=rate, disable_tqdm=args.disable_tqdm, enable_multi=args.multi, + extra_request_body=extra_request_body, ) ) else: @@ -851,6 +862,7 @@ def fire(args: argparse.Namespace): request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, enable_multi=args.multi, + extra_request_body=extra_request_body, ) ) @@ -976,6 +988,13 @@ if __name__ == "__main__": action="store_true", help="Disable ignoring EOS.", ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help="Append given JSON object to the request payload. 
You can use this to specify "
+        "additional generate params like sampling params.",
+    )
 
     set_ulimit()

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 714777dc1..4f89ba3b9 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -24,6 +24,7 @@ import numpy as np
 import torch
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
+import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
@@ -222,8 +223,9 @@ class Req:
             )
             return
 
+        last_token_id = self.output_ids[-1]
         if (
-            self.output_ids[-1] == self.tokenizer.eos_token_id
+            last_token_id == self.tokenizer.eos_token_id
             and not self.sampling_params.ignore_eos
         ):
             self.finished_reason = FINISH_MATCHED_TOKEN(
@@ -231,6 +233,10 @@ class Req:
             )
             return
 
+        if last_token_id in self.sampling_params.stop_token_ids:
+            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
+            return
+
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
@@ -321,8 +327,7 @@ class ScheduleBatch:
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
-    frequency_penalties: torch.Tensor = None
-    presence_penalties: torch.Tensor = None
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     logit_bias: torch.Tensor = None
 
     @classmethod
@@ -386,15 +391,24 @@
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
         )
-        self.frequency_penalties = torch.tensor(
-            [r.sampling_params.frequency_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.presence_penalties = torch.tensor(
-            [r.sampling_params.presence_penalty for r in reqs],
-            dtype=torch.float,
+
+        # Each penalizer does nothing if it finds that it is not required, by looking at the
+        # sampling_params of the requests (see {_is_required()} of each penalizer). So this
+        # should not add heavy computation overhead beyond those simple checks.
+        #
+        # We could go further and not even create the instances when they are not required, but
+        # that would add complexity to the {ScheduleBatch} class, especially since we would need
+        # to handle the {filter_batch()} and {merge()} cases as well.
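+        #
+        # Note that the orchestrator cumulates the prompt tokens of every request at construction
+        # time (see {BatchedPenalizerOrchestrator.__init__}), so penalizers that depend on input
+        # tokens, such as the repetition penalty, take effect from the first decode step onwards.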
+ self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator( + vocab_size=vocab_size, + batch=self, device=device, + Penalizers={ + penaltylib.BatchedFrequencyPenalizer, + penaltylib.BatchedMinNewTokensPenalizer, + penaltylib.BatchedPresencePenalizer, + penaltylib.BatchedRepetitionPenalizer, + }, ) # Handle logit bias but only allocate when needed @@ -617,6 +631,9 @@ class ScheduleBatch: input_ids = [ r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs ] + else: + self.penalizer_orchestrator.cumulate_input_tokens(input_ids) + self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda") self.seq_lens.add_(1) @@ -648,12 +665,12 @@ class ScheduleBatch: self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices] self.return_logprob = any(req.return_logprob for req in self.reqs) + self.penalizer_orchestrator.filter(unfinished_indices, new_indices) + for item in [ "temperatures", "top_ps", "top_ks", - "frequency_penalties", - "presence_penalties", "logit_bias", ]: self_val = getattr(self, item, None) @@ -674,12 +691,12 @@ class ScheduleBatch: self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) + self.penalizer_orchestrator.merge(other.penalizer_orchestrator) + for item in [ "temperatures", "top_ps", "top_ks", - "frequency_penalties", - "presence_penalties", ]: self_val = getattr(self, item, None) other_val = getattr(other, item, None) @@ -721,7 +738,8 @@ class ScheduleBatch: ] = 1 logits[i].masked_fill_(~allowed_mask, float("-inf")) - # TODO(lmzheng): apply penalty + logits = self.penalizer_orchestrator.apply(logits) + probs = torch.softmax(logits, dim=-1) if not global_server_args_dict["disable_flashinfer_sampling"]: @@ -754,6 +772,8 @@ class ScheduleBatch: req.regex_fsm_state, batch_next_token_ids_cpu[i] ) + self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids) + return batch_next_token_ids diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 5e21d67e4..c12138391 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -392,10 +392,13 @@ def v1_generate_request(all_requests): { "temperature": request.temperature, "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, "stop": request.stop, + "stop_token_ids": request.stop_token_ids, "top_p": request.top_p, "presence_penalty": request.presence_penalty, "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, "regex": request.regex, "n": request.n, "ignore_eos": request.ignore_eos, @@ -722,10 +725,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager): { "temperature": request.temperature, "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, "stop": stop, + "stop_token_ids": request.stop_token_ids, "top_p": request.top_p, "presence_penalty": request.presence_penalty, "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, "regex": request.regex, "n": request.n, } diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 2910dd5cd..75f0a1aab 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -162,6 +162,9 @@ class CompletionRequest(BaseModel): # Extra parameters for SRT backend only and will be ignored by OpenAI models. 
regex: Optional[str] = None ignore_eos: Optional[bool] = False + min_tokens: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) class CompletionResponseChoice(BaseModel): @@ -259,6 +262,9 @@ class ChatCompletionRequest(BaseModel): # Extra parameters for SRT backend only and will be ignored by OpenAI models. regex: Optional[str] = None + min_tokens: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) class ChatMessage(BaseModel): diff --git a/python/sglang/srt/sampling/penaltylib/__init__.py b/python/sglang/srt/sampling/penaltylib/__init__.py new file mode 100644 index 000000000..43fff0fca --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/__init__.py @@ -0,0 +1,13 @@ +from .orchestrator import BatchedPenalizerOrchestrator +from .penalizers.frequency_penalty import BatchedFrequencyPenalizer +from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer +from .penalizers.presence_penalty import BatchedPresencePenalizer +from .penalizers.repetition_penalty import BatchedRepetitionPenalizer + +__all__ = [ + "BatchedFrequencyPenalizer", + "BatchedMinNewTokensPenalizer", + "BatchedPresencePenalizer", + "BatchedRepetitionPenalizer", + "BatchedPenalizerOrchestrator", +] diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py new file mode 100644 index 000000000..969a5d820 --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -0,0 +1,353 @@ +import abc +import dataclasses +import typing + +import torch + + +@dataclasses.dataclass +class _ReqLike: + origin_input_ids: typing.Union[torch.Tensor, typing.List[int]] + + +@dataclasses.dataclass +class _BatchLike: + reqs: typing.List[_ReqLike] + + def batch_size(self): + return len(self.reqs) + + +class BatchedPenalizerOrchestrator: + batch: _BatchLike + device: str + vocab_size: int + penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"] + + def __init__( + self, + vocab_size: int, + batch: _BatchLike, + device: str, + Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]], + ): + self.vocab_size = vocab_size + self.batch = batch + self.device = device + + self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers} + + for penalizer in self.penalizers.values(): + penalizer.prepare_if_required() + + self.cumulate_input_tokens( + input_ids=[req.origin_input_ids for req in self.reqs()] + ) + + def reqs(self): + return self.batch.reqs + + def batch_size(self): + return self.batch.batch_size() + + def cumulate_input_tokens( + self, + input_ids: typing.Union[ + typing.List[torch.Tensor], typing.List[typing.List[int]] + ], + ): + """ + Feed the input tokens to the penalizers. + + Args: + input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens. + """ + token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids) + + for penalizer in self.penalizers.values(): + penalizer.cumulate_input_tokens(input_ids=token_ids) + + def cumulate_output_tokens( + self, + output_ids: typing.Union[ + typing.List[torch.Tensor], typing.List[typing.List[int]] + ], + ): + """ + Feed the output tokens to the penalizers. + + Args: + output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens. 
+ """ + token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids) + + for penalizer in self.penalizers.values(): + penalizer.cumulate_output_tokens(output_ids=token_ids) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + """ + Apply the penalizers to the logits. + Note that it may apply the penalizers in-place. + + Args: + logits (torch.Tensor): The logits to apply the penalizers to. + + Returns: + torch.Tensor: The logits after applying the penalizers. + """ + for penalizer in self.penalizers.values(): + logits = penalizer.apply(logits) + + return logits + + def filter( + self, + indices_to_keep: typing.List[int], + indices_tensor_to_keep: torch.Tensor = None, + ): + """ + Filter the penalizers based on the indices to keep in the batch. + + Args: + indices_to_keep (typing.List[int]): List of indices to keep in the batch. + indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor. + """ + empty_indices = len(indices_to_keep) == 0 + + for penalizer in self.penalizers.values(): + if not penalizer.is_required() or empty_indices: + penalizer.teardown() + else: + # create tensor index only when it's needed + if indices_tensor_to_keep is None: + indices_tensor_to_keep = torch.tensor( + indices_to_keep, dtype=torch.int32, device=self.device + ) + + penalizer.filter( + indices_to_keep=indices_to_keep, + indices_tensor_to_keep=indices_tensor_to_keep, + ) + + def merge(self, their: "BatchedPenalizerOrchestrator"): + """ + Merge the penalizers of another orchestrator into this one. + + Args: + their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one. + """ + if self.vocab_size != their.vocab_size: + raise ValueError( + f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}" + ) + + for Penalizer, their_penalizer in their.penalizers.items(): + if Penalizer not in self.penalizers: + raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers") + + self.penalizers[Penalizer].merge(their_penalizer) + + +class _TokenIDs: + """ + A class that wraps token IDs to provide additional utility functions to penalizers. + + Attributes: + orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to. + token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs. + cached_counts (torch.Tensor): The cached occurrence count tensor. + """ + + orchestrator: BatchedPenalizerOrchestrator + token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]] + cached_counts: torch.Tensor = None + + def __init__( + self, + orchestrator: BatchedPenalizerOrchestrator, + token_ids: typing.Union[ + typing.List[torch.Tensor], typing.List[typing.List[int]] + ], + ): + self.orchestrator = orchestrator + + if not isinstance(token_ids[0], torch.Tensor): + token_ids = [ + torch.tensor( + data=ids, dtype=torch.int64, device=self.orchestrator.device + ) + for ids in token_ids + ] + + self.token_ids = token_ids + + def occurrence_count(self) -> torch.Tensor: + """ + Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch. + + Returns: + torch.Tensor: The occurrence count tensor. 
+ """ + if self.cached_counts is not None: + return self.cached_counts + + token_ids = self.token_ids + + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.unsqueeze(1) + + # needs to be long to be used as index in scatter_add + if token_ids.dtype != torch.int64: + token_ids = token_ids.to(torch.int64) + + padded_token_ids = torch.nn.utils.rnn.pad_sequence( + sequences=token_ids, + batch_first=True, + padding_value=self.orchestrator.vocab_size, + ) + + self.cached_counts = torch.zeros( + size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), + dtype=torch.int64, + device=self.orchestrator.device, + ).scatter_add_( + dim=1, + index=padded_token_ids, + src=torch.ones_like(padded_token_ids), + )[ + :, : self.orchestrator.vocab_size + ] + + return self.cached_counts + + +class _BatchedPenalizer(abc.ABC): + """ + An abstract class for a batched penalizer. + """ + + orchestrator: BatchedPenalizerOrchestrator + _is_prepared: bool = False + + def __init__(self, orchestrator: BatchedPenalizerOrchestrator): + self.orchestrator = orchestrator + + def is_prepared(self) -> bool: + return self._is_prepared + + def is_required(self) -> bool: + return self._is_required() + + def prepare(self): + if not self.is_prepared(): + self._prepare() + self._is_prepared = True + + def prepare_if_required(self): + if self.is_required(): + self.prepare() + + def teardown(self): + if self.is_prepared(): + self._teardown() + self._is_prepared = False + + def cumulate_input_tokens(self, input_ids: _TokenIDs): + if not self.is_prepared(): + return + + self._cumulate_input_tokens(input_ids=input_ids) + + def cumulate_output_tokens(self, output_ids: _TokenIDs): + if not self.is_prepared(): + return + + self._cumulate_output_tokens(output_ids=output_ids) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.is_prepared(): + return logits + + return self._apply(logits=logits) + + def filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + if not self.is_prepared(): + return + + self._filter( + indices_to_keep=indices_to_keep, + indices_tensor_to_keep=indices_tensor_to_keep, + ) + + def merge(self, their: "_BatchedPenalizer"): + if not self.is_prepared() and not their.is_prepared(): + return + + self.prepare() + their.prepare() + self._merge(their) + + @abc.abstractmethod + def _is_required(self) -> bool: + """ + Check if the penalizer is required to be prepared. + """ + pass + + @abc.abstractmethod + def _prepare(self): + """ + Prepare the penalizer. + Usually, this is where the penalizer initializes its tensors. + """ + pass + + @abc.abstractmethod + def _teardown(self): + """ + Tear down the penalizer. + Usually, this is where the penalizer frees its tensors. + """ + pass + + @abc.abstractmethod + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + """ + Cumulate the input tokens. + Orchestrator will call this function to feed the input tokens to the penalizer. + """ + pass + + @abc.abstractmethod + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + """ + Cumulate the output tokens. + Orchestrator will call this function to feed the output tokens to the penalizer. + """ + pass + + @abc.abstractmethod + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + """ + Apply the penalizer to the logits. + Penalizers can modify the logits in-place if needed. 
+ """ + pass + + @abc.abstractmethod + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + """ + Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch. + """ + pass + + @abc.abstractmethod + def _merge(self, their: "_BatchedPenalizer"): + """ + Merge the penalizer with another penalizer. + """ + pass diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py new file mode 100644 index 000000000..178cb54b2 --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py @@ -0,0 +1,80 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedFrequencyPenalizer(_BatchedPenalizer): + """ + Frequency penalizer penalizes tokens based on their frequency in the output. + """ + + frequency_penalties: torch.Tensor = None + cumulated_frequency_penalties: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.frequency_penalty != 0.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_frequency_penalties = ( + torch.tensor( + data=[0.0 for _ in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .repeat(1, self.orchestrator.vocab_size) + ) + + self.frequency_penalties = ( + torch.tensor( + data=[ + req.sampling_params.frequency_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .expand_as(self.cumulated_frequency_penalties) + ) + + def _teardown(self): + del self.frequency_penalties + del self.cumulated_frequency_penalties + + self.frequency_penalties = None + self.cumulated_frequency_penalties = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + pass + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + self.cumulated_frequency_penalties += ( + self.frequency_penalties * output_ids.occurrence_count() + ) + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + logits -= self.cumulated_frequency_penalties + return logits + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep] + self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[ + indices_tensor_to_keep + ] + + def _merge(self, their: "BatchedFrequencyPenalizer"): + self.frequency_penalties = torch.cat( + [self.frequency_penalties, their.frequency_penalties], dim=0 + ) + self.cumulated_frequency_penalties = torch.cat( + [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties], + dim=0, + ) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py b/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py new file mode 100644 index 000000000..c9e0f078e --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py @@ -0,0 +1,105 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedMinNewTokensPenalizer(_BatchedPenalizer): + """ + Min new tokens penalizer penalizes tokens based on the length of the output. 
+ """ + + min_new_tokens: torch.Tensor = None + stop_token_penalties: torch.Tensor = None + len_output_tokens: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.min_new_tokens = torch.tensor( + data=[ + req.sampling_params.min_new_tokens for req in self.orchestrator.reqs() + ], + dtype=torch.int32, + device=self.orchestrator.device, + ).unsqueeze_(1) + + padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence( + sequences=[ + torch.tensor( + data=list( + req.sampling_params.stop_token_ids + | {req.tokenizer.eos_token_id} + ), + dtype=torch.int64, + device=self.orchestrator.device, + ) + for req in self.orchestrator.reqs() + ], + batch_first=True, + padding_value=self.orchestrator.vocab_size, + ) + self.stop_token_penalties = torch.zeros( + size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), + dtype=torch.float32, + device=self.orchestrator.device, + ).scatter_add_( + dim=1, + index=padded_stop_token_ids, + src=torch.full_like( + input=padded_stop_token_ids, + dtype=torch.float32, + fill_value=float("-inf"), + device=self.orchestrator.device, + ), + )[ + :, : self.orchestrator.vocab_size + ] + + self.len_output_tokens = torch.zeros( + size=(self.orchestrator.batch_size(), 1), + dtype=torch.int32, + device=self.orchestrator.device, + ) + + def _teardown(self): + del self.min_new_tokens + del self.stop_token_penalties + del self.len_output_tokens + + self.min_new_tokens = None + self.stop_token_penalties = None + self.len_output_tokens = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + pass + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + self.len_output_tokens += 1 + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits) + logits[mask] += self.stop_token_penalties[mask] + return logits + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep] + self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep] + self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep] + + def _merge(self, their: "BatchedMinNewTokensPenalizer"): + self.min_new_tokens = torch.cat( + [self.min_new_tokens, their.min_new_tokens], dim=0 + ) + self.stop_token_penalties = torch.cat( + [self.stop_token_penalties, their.stop_token_penalties], dim=0 + ) + self.len_output_tokens = torch.cat( + [self.len_output_tokens, their.len_output_tokens], dim=0 + ) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py new file mode 100644 index 000000000..0593fddc9 --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py @@ -0,0 +1,79 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedPresencePenalizer(_BatchedPenalizer): + """ + Presence penalizer penalizes tokens based on their presence in the output. 
+ """ + + presence_penalties: torch.Tensor = None + cumulated_presence_penalties: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.presence_penalty != 0.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_presence_penalties = ( + torch.tensor( + data=[0.0 for _ in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .repeat(1, self.orchestrator.vocab_size) + ) + + self.presence_penalties = ( + torch.tensor( + data=[ + req.sampling_params.presence_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .expand_as(self.cumulated_presence_penalties) + ) + + def _teardown(self): + del self.presence_penalties + del self.cumulated_presence_penalties + + self.presence_penalties = None + self.cumulated_presence_penalties = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + pass + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + mask = output_ids.occurrence_count() > 0 + self.cumulated_presence_penalties[mask] = self.presence_penalties[mask] + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + logits -= self.cumulated_presence_penalties + return logits + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.presence_penalties = self.presence_penalties[indices_tensor_to_keep] + self.cumulated_presence_penalties = self.cumulated_presence_penalties[ + indices_tensor_to_keep + ] + + def _merge(self, their: "BatchedPresencePenalizer"): + self.presence_penalties = torch.cat( + [self.presence_penalties, their.presence_penalties], dim=0 + ) + self.cumulated_presence_penalties = torch.cat( + [self.cumulated_presence_penalties, their.cumulated_presence_penalties], + dim=0, + ) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py new file mode 100644 index 000000000..ea32addc2 --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py @@ -0,0 +1,83 @@ +import typing + +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + + +class BatchedRepetitionPenalizer(_BatchedPenalizer): + """ + Repetition penalizer penalizes tokens based on their repetition in the input and output. 
+ """ + + repetition_penalties: torch.Tensor = None + cumulated_repetition_penalties: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.repetition_penalty != 1.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_repetition_penalties = ( + torch.tensor( + data=[1.0 for _ in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .repeat(1, self.orchestrator.vocab_size) + ) + + self.repetition_penalties = ( + torch.tensor( + data=[ + req.sampling_params.repetition_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + .unsqueeze_(1) + .expand_as(self.cumulated_repetition_penalties) + ) + + def _teardown(self): + del self.repetition_penalties + del self.cumulated_repetition_penalties + + self.repetition_penalties = None + self.cumulated_repetition_penalties = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + mask = input_ids.occurrence_count() > 0 + self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + mask = output_ids.occurrence_count() > 0 + self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + return torch.where( + logits > 0, + logits / self.cumulated_repetition_penalties, + logits * self.cumulated_repetition_penalties, + ) + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] + self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[ + indices_tensor_to_keep + ] + + def _merge(self, their: "BatchedRepetitionPenalizer"): + self.repetition_penalties = torch.cat( + [self.repetition_penalties, their.repetition_penalties], dim=0 + ) + self.cumulated_repetition_penalties = torch.cat( + [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties], + dim=0, + ) diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py index 89091b7ae..39774d9ac 100644 --- a/python/sglang/srt/sampling_params.py +++ b/python/sglang/srt/sampling_params.py @@ -24,12 +24,15 @@ class SamplingParams: def __init__( self, max_new_tokens: int = 128, + min_new_tokens: int = 0, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = [], temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, + repetition_penalty: float = 1.0, ignore_eos: bool = False, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, @@ -42,8 +45,11 @@ class SamplingParams: self.top_k = top_k self.frequency_penalty = frequency_penalty self.presence_penalty = presence_penalty + self.repetition_penalty = repetition_penalty self.stop_strs = stop + self.stop_token_ids = {*stop_token_ids} self.max_new_tokens = max_new_tokens + self.min_new_tokens = min_new_tokens self.ignore_eos = ignore_eos self.skip_special_tokens = skip_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens @@ -80,11 +86,26 @@ class SamplingParams: raise ValueError( "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}." ) + if not 0.0 <= self.repetition_penalty <= 2.0: + raise ValueError( + "repetition_penalty must be in (0, 2], got " + f"{self.repetition_penalty}." 
+            )
+        if not 0 <= self.min_new_tokens:
+            raise ValueError(
+                f"min_new_tokens must be in [0, max_new_tokens], got "
+                f"{self.min_new_tokens}."
+            )
         if self.max_new_tokens is not None:
             if self.max_new_tokens < 0:
                 raise ValueError(
                     f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
                 )
+            if not self.min_new_tokens <= self.max_new_tokens:
+                raise ValueError(
+                    f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got "
+                    f"{self.min_new_tokens}."
+                )
 
     def normalize(self, tokenizer):
         # Process stop strings
diff --git a/python/sglang/test/srt/sampling/penaltylib/utils.py b/python/sglang/test/srt/sampling/penaltylib/utils.py
new file mode 100644
index 000000000..b41eac32b
--- /dev/null
+++ b/python/sglang/test/srt/sampling/penaltylib/utils.py
@@ -0,0 +1,337 @@
+import dataclasses
+import enum
+import typing
+import unittest
+
+import torch
+
+from sglang.srt.sampling.penaltylib.orchestrator import (
+    BatchedPenalizerOrchestrator,
+    _BatchedPenalizer,
+    _BatchLike,
+)
+
+
+@dataclasses.dataclass
+class MockSamplingParams:
+    frequency_penalty: float = 0.0
+    min_new_tokens: int = 0
+    stop_token_ids: typing.Set[int] = None
+    presence_penalty: float = 0.0
+    repetition_penalty: float = 1.0
+
+
+@dataclasses.dataclass
+class MockTokenizer:
+    eos_token_id: int
+
+
+@dataclasses.dataclass
+class MockReq:
+    origin_input_ids: typing.List[int]
+    sampling_params: MockSamplingParams
+    tokenizer: MockTokenizer
+
+
+class StepType(enum.Enum):
+    INPUT = "input"
+    OUTPUT = "output"
+
+
+@dataclasses.dataclass
+class Step:
+    type: StepType
+    token_ids: typing.List[int]
+    expected_tensors: typing.Dict[str, torch.Tensor]
+    # assume initial logits are all 1
+    expected_logits: torch.Tensor
+
+
+@dataclasses.dataclass
+class Subject:
+    sampling_params: MockSamplingParams
+    # first step must be input, which will be converted to Req
+    steps: typing.List[Step]
+    eos_token_id: int = -1
+
+    def __post_init__(self):
+        if self.steps[0].type != StepType.INPUT:
+            raise ValueError("First step must be input")
+
+        # all steps should have the same expected_tensors.keys()
+        for i in range(1, len(self.steps)):
+            if self.tensor_keys(i) != self.tensor_keys():
+                raise ValueError(
+                    f"Expected tensor keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for step={i} and {self.steps[0].expected_tensors.keys()}"
+                )
+
+    def tensor_keys(self, i: int = 0) -> typing.Set[str]:
+        return set(self.steps[i].expected_tensors.keys())
+
+    def to_req(self) -> MockReq:
+        return MockReq(
+            origin_input_ids=self.steps[0].token_ids,
+            sampling_params=self.sampling_params,
+            tokenizer=MockTokenizer(eos_token_id=self.eos_token_id),
+        )
+
+
+@dataclasses.dataclass
+class Case:
+    enabled: bool
+    test_subjects: typing.List[Subject]
+
+    def __post_init__(self):
+        # all test subjects' steps should have the same expected_tensors.keys()
+        for i in range(1, len(self.test_subjects)):
+            if self.tensor_keys(i) != self.tensor_keys():
+                raise ValueError(
+                    f"Expected tensor keys must be the same for all test_subjects. 
Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}" + ) + + def tensor_keys(self, i: int = 0) -> typing.List[str]: + return set(self.test_subjects[i].tensor_keys()) + + +class BaseBatchedPenalizerTest(unittest.TestCase): + Penalizer: typing.Type[_BatchedPenalizer] + device = "cuda" + vocab_size = 5 + + enabled: Subject = None + disabled: Subject = None + + def setUp(self): + if self.__class__ == BaseBatchedPenalizerTest: + self.skipTest("Base class for penalizer tests") + + self.create_test_subjects() + self.create_test_cases() + + def tensor(self, data, **kwargs) -> torch.Tensor: + """ + Shortcut to create a tensor with device=self.device. + """ + return torch.tensor(data, **kwargs, device=self.device) + + def create_test_subjects(self) -> typing.List[Subject]: + raise NotImplementedError() + + def create_test_cases(self): + self.test_cases = [ + Case(enabled=True, test_subjects=[self.enabled]), + Case(enabled=False, test_subjects=[self.disabled]), + Case(enabled=True, test_subjects=[self.enabled, self.disabled]), + ] + + def _create_penalizer( + self, case: Case + ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]: + orchestrator = BatchedPenalizerOrchestrator( + vocab_size=self.vocab_size, + batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]), + device=self.device, + Penalizers={self.Penalizer}, + ) + + return orchestrator, orchestrator.penalizers[self.Penalizer] + + def test_is_required(self): + for case in self.test_cases: + with self.subTest(case=case): + _, penalizer = self._create_penalizer(case) + self.assertEqual(case.enabled, penalizer.is_required()) + + def test_prepare(self): + for case in self.test_cases: + with self.subTest(case=case): + orchestrator, penalizer = self._create_penalizer(case) + self.assertEqual(case.enabled, penalizer.is_prepared()) + + if case.enabled: + for key, tensor in { + key: torch.cat( + tensors=[ + subject.steps[0].expected_tensors[key] + for subject in case.test_subjects + ], + ) + for key in case.tensor_keys() + }.items(): + torch.testing.assert_close( + actual=getattr(penalizer, key), + expected=tensor, + msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", + ) + + actual = orchestrator.apply( + torch.ones( + size=(len(case.test_subjects), self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + ) + expected = torch.cat( + tensors=[ + subject.steps[0].expected_logits + for subject in case.test_subjects + ], + ) + torch.testing.assert_close( + actual=actual, + expected=expected, + msg=f"logits\nactual={actual}\nexpected={expected}", + ) + + def test_teardown(self): + for case in self.test_cases: + with self.subTest(case=case): + _, penalizer = self._create_penalizer(case) + penalizer.teardown() + + for key in case.test_subjects[0].steps[0].expected_tensors.keys(): + self.assertIsNone(getattr(penalizer, key, None)) + + def test_filter(self): + for case in self.test_cases: + with self.subTest(case=case): + orchestrator, penalizer = self._create_penalizer(case) + + indices_to_keep = [0] + orchestrator.filter(indices_to_keep=indices_to_keep) + + filtered_subjects = [case.test_subjects[i] for i in indices_to_keep] + + if penalizer.is_required(): + self.assertTrue(penalizer.is_prepared()) + for key, tensor in { + key: torch.cat( + tensors=[ + subject.steps[0].expected_tensors[key] + for subject in filtered_subjects + ], + ) + for key in case.tensor_keys() + }.items(): + torch.testing.assert_close( + actual=getattr(penalizer, key), + 
expected=tensor, + msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", + ) + + actual_logits = orchestrator.apply( + torch.ones( + size=(len(filtered_subjects), self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + ) + filtered_expected_logits = torch.cat( + tensors=[ + subject.steps[0].expected_logits + for subject in filtered_subjects + ], + ) + torch.testing.assert_close( + actual=actual_logits, + expected=filtered_expected_logits, + msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}", + ) + + def test_merge_enabled_with_disabled(self): + enabled_test_case = self.test_cases[0] + disabled_test_case = self.test_cases[1] + + orchestrator, penalizer = self._create_penalizer(enabled_test_case) + theirs, _ = self._create_penalizer(disabled_test_case) + + orchestrator.merge(theirs) + + for key, tensor in { + key: torch.cat( + tensors=[ + enabled_test_case.test_subjects[0].steps[0].expected_tensors[key], + disabled_test_case.test_subjects[0].steps[0].expected_tensors[key], + ], + ) + for key in enabled_test_case.tensor_keys() + }.items(): + torch.testing.assert_close( + actual=getattr(penalizer, key), + expected=tensor, + msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", + ) + + def test_cumulate_apply_repeat(self): + for case in self.test_cases: + with self.subTest(case=case): + orchestrator, penalizer = self._create_penalizer(case) + + max_step = max(len(subject.steps) for subject in case.test_subjects) + for i in range(1, max_step): + orchestrator.filter( + indices_to_keep=[ + j + for j, subject in enumerate(case.test_subjects) + if i < len(subject.steps) + ] + ) + + filtered_subjects = [ + subject + for subject in case.test_subjects + if i < len(subject.steps) + ] + + inputs: typing.List[typing.List[int]] = [] + outputs: typing.List[typing.List[int]] = [] + for subject in filtered_subjects: + step = subject.steps[i] + if step.type == StepType.INPUT: + inputs.append(step.token_ids) + outputs.append([]) + else: + inputs.append([]) + outputs.append(step.token_ids) + + if any(inputs): + orchestrator.cumulate_input_tokens(inputs) + + if any(outputs): + orchestrator.cumulate_output_tokens(outputs) + + if penalizer.is_required(): + self.assertTrue(penalizer.is_prepared()) + for key, tensor in { + key: torch.cat( + tensors=[ + subject.steps[i].expected_tensors[key] + for subject in filtered_subjects + ], + ) + for key in case.tensor_keys() + }.items(): + torch.testing.assert_close( + actual=getattr(penalizer, key), + expected=tensor, + msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", + ) + + actual_logits = orchestrator.apply( + torch.ones( + size=(len(filtered_subjects), self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + ) + filtered_expected_logits = torch.cat( + tensors=[ + subject.steps[i].expected_logits + for subject in filtered_subjects + ], + ) + torch.testing.assert_close( + actual=actual_logits, + expected=filtered_expected_logits, + msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}", + ) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 43f2730c7..c6212dc39 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -482,7 +482,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float): p.terminate() time.sleep(5) print( - "\nTimeout after {timeout_per_file} seconds when running {filename}\n" + f"\nTimeout after {timeout_per_file} seconds when running {filename}\n" ) 
return False diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index f993b7e8b..edb8db316 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -11,9 +11,20 @@ suites = { "test_chunked_prefill.py", "test_torch_compile.py", "models/test_causal_models.py", + "sampling/penaltylib", ], + "sampling/penaltylib": glob.glob( + "sampling/penaltylib/**/test_*.py", recursive=True + ), } +for target_suite_name, target_tests in suites.items(): + for suite_name, tests in suites.items(): + if suite_name == target_suite_name: + continue + if target_suite_name in tests: + tests.remove(target_suite_name) + tests.extend(target_tests) if __name__ == "__main__": arg_parser = argparse.ArgumentParser() diff --git a/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py new file mode 100644 index 000000000..b659a04fc --- /dev/null +++ b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py @@ -0,0 +1,80 @@ +import typing +import unittest + +import torch + +from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import ( + BatchedFrequencyPenalizer, +) +from sglang.test.srt.sampling.penaltylib.utils import ( + BaseBatchedPenalizerTest, + MockSamplingParams, + Step, + StepType, + Subject, +) + +FREQUENCY_PENALTY = 0.12 + + +class TestBatchedFrequencyPenalizer(BaseBatchedPenalizerTest): + Penalizer = BatchedFrequencyPenalizer + + def _create_subject(self, frequency_penalty: float) -> Subject: + return Subject( + sampling_params=MockSamplingParams( + frequency_penalty=frequency_penalty, + ), + steps=[ + Step( + type=StepType.INPUT, + token_ids=[0, 1, 2], + expected_tensors={ + "frequency_penalties": self.tensor( + [[frequency_penalty] * self.vocab_size], dtype=torch.float32 + ), + "cumulated_frequency_penalties": self.tensor( + [[0.0] * self.vocab_size], dtype=torch.float32 + ), + }, + expected_logits=self.tensor( + [[1] * self.vocab_size], dtype=torch.float32 + ), + ), + Step( + type=StepType.OUTPUT, + token_ids=[1, 2, 2], + expected_tensors={ + "frequency_penalties": self.tensor( + [[frequency_penalty] * self.vocab_size], dtype=torch.float32 + ), + "cumulated_frequency_penalties": self.tensor( + [ + [ + frequency_penalty * i if i in {1, 2} else 0.0 + for i in range(self.vocab_size) + ], + ], + dtype=torch.float32, + ), + }, + expected_logits=self.tensor( + [ + [ + 1.0 - frequency_penalty * i if i in {1, 2} else 1.0 + for i in range(self.vocab_size) + ], + ], + dtype=torch.float32, + ), + ), + ], + ) + + def create_test_subjects(self) -> typing.List[Subject]: + self.enabled = self._create_subject(frequency_penalty=FREQUENCY_PENALTY) + self.disabled = self._create_subject(frequency_penalty=0.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py b/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py new file mode 100644 index 000000000..1984aafe5 --- /dev/null +++ b/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py @@ -0,0 +1,152 @@ +import typing +import unittest + +import torch + +from sglang.srt.sampling.penaltylib.penalizers.min_new_tokens import ( + BatchedMinNewTokensPenalizer, +) +from sglang.test.srt.sampling.penaltylib.utils import ( + BaseBatchedPenalizerTest, + MockSamplingParams, + Step, + StepType, + Subject, +) + +MIN_NEW_TOKENS = 2 +EOS_TOKEN_ID = 4 +STOP_TOKEN_ID = 3 + +ALL_STOP_TOKEN_IDS = {STOP_TOKEN_ID, EOS_TOKEN_ID} + + +class 
TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
+    Penalizer = BatchedMinNewTokensPenalizer
+
+    def _create_subject(self, min_new_tokens: int) -> Subject:
+        return Subject(
+            eos_token_id=EOS_TOKEN_ID,
+            sampling_params=MockSamplingParams(
+                min_new_tokens=min_new_tokens,
+                stop_token_ids={STOP_TOKEN_ID},
+            ),
+            steps=[
+                Step(
+                    type=StepType.INPUT,
+                    token_ids=[0, 1, 2],
+                    expected_tensors={
+                        "min_new_tokens": self.tensor(
+                            [[min_new_tokens]], dtype=torch.int32
+                        ),
+                        "stop_token_penalties": self.tensor(
+                            [
+                                [
+                                    float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
+                                    for i in range(self.vocab_size)
+                                ]
+                            ],
+                            dtype=torch.float32,
+                        ),
+                        "len_output_tokens": self.tensor([[0]], dtype=torch.int32),
+                    },
+                    expected_logits=(
+                        self.tensor(
+                            [
+                                [
+                                    float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
+                                    for i in range(self.vocab_size)
+                                ]
+                            ],
+                            dtype=torch.float32,
+                        )
+                        if min_new_tokens > 0
+                        else torch.ones(
+                            (1, self.vocab_size),
+                            dtype=torch.float32,
+                            device=self.device,
+                        )
+                    ),
+                ),
+                Step(
+                    type=StepType.OUTPUT,
+                    token_ids=[0],
+                    expected_tensors={
+                        "min_new_tokens": self.tensor(
+                            [[min_new_tokens]], dtype=torch.int32
+                        ),
+                        "stop_token_penalties": self.tensor(
+                            [
+                                [
+                                    float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
+                                    for i in range(self.vocab_size)
+                                ]
+                            ],
+                            dtype=torch.float32,
+                        ),
+                        "len_output_tokens": self.tensor([[1]], dtype=torch.int32),
+                    },
+                    expected_logits=(
+                        self.tensor(
+                            [
+                                [
+                                    float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
+                                    for i in range(self.vocab_size)
+                                ]
+                            ],
+                            dtype=torch.float32,
+                        )
+                        if min_new_tokens > 1
+                        else torch.ones(
+                            (1, self.vocab_size),
+                            dtype=torch.float32,
+                            device=self.device,
+                        )
+                    ),
+                ),
+                Step(
+                    type=StepType.OUTPUT,
+                    token_ids=[0],
+                    expected_tensors={
+                        "min_new_tokens": self.tensor(
+                            [[min_new_tokens]], dtype=torch.int32
+                        ),
+                        "stop_token_penalties": self.tensor(
+                            [
+                                [
+                                    float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
+                                    for i in range(self.vocab_size)
+                                ]
+                            ],
+                            dtype=torch.float32,
+                        ),
+                        "len_output_tokens": self.tensor([[2]], dtype=torch.int32),
+                    },
+                    expected_logits=(
+                        self.tensor(
+                            [
+                                [
+                                    float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
+                                    for i in range(self.vocab_size)
+                                ]
+                            ],
+                            dtype=torch.float32,
+                        )
+                        if min_new_tokens > 2
+                        else torch.ones(
+                            (1, self.vocab_size),
+                            dtype=torch.float32,
+                            device=self.device,
+                        )
+                    ),
+                ),
+            ],
+        )
+
+    def create_test_subjects(self) -> typing.List[Subject]:
+        self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
+        self.disabled = self._create_subject(min_new_tokens=0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py
new file mode 100644
index 000000000..30cb2b9a0
--- /dev/null
+++ b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py
@@ -0,0 +1,80 @@
+import typing
+import unittest
+
+import torch
+
+from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import (
+    BatchedPresencePenalizer,
+)
+from sglang.test.srt.sampling.penaltylib.utils import (
+    BaseBatchedPenalizerTest,
+    MockSamplingParams,
+    Step,
+    StepType,
+    Subject,
+)
+
+PRESENCE_PENALTY = 0.12
+
+
+class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest):
+    Penalizer = BatchedPresencePenalizer
+
+    def _create_subject(self, presence_penalty: float) -> Subject:
+        return Subject(
+            sampling_params=MockSamplingParams(
+                presence_penalty=presence_penalty,
+            ),
+            steps=[
+                Step(
+                    type=StepType.INPUT,
+                    token_ids=[0, 1, 2],
+                    expected_tensors={
+                        "presence_penalties": self.tensor(
+                            [[presence_penalty] * self.vocab_size], dtype=torch.float32
+                        ),
+                        "cumulated_presence_penalties": self.tensor(
+                            [[0.0] * self.vocab_size], dtype=torch.float32
+                        ),
+                    },
+                    expected_logits=self.tensor(
+                        [[1] * self.vocab_size], dtype=torch.float32
+                    ),
+                ),
+                Step(
+                    type=StepType.OUTPUT,
+                    token_ids=[1, 2, 2],
+                    expected_tensors={
+                        "presence_penalties": self.tensor(
+                            [[presence_penalty] * self.vocab_size], dtype=torch.float32
+                        ),
+                        "cumulated_presence_penalties": self.tensor(
+                            [
+                                [
+                                    presence_penalty if i in {1, 2} else 0.0
+                                    for i in range(self.vocab_size)
+                                ],
+                            ],
+                            dtype=torch.float32,
+                        ),
+                    },
+                    expected_logits=self.tensor(
+                        [
+                            [
+                                1.0 - presence_penalty if i in {1, 2} else 1.0
+                                for i in range(self.vocab_size)
+                            ],
+                        ],
+                        dtype=torch.float32,
+                    ),
+                ),
+            ],
+        )
+
+    def create_test_subjects(self) -> typing.List[Subject]:
+        self.enabled = self._create_subject(presence_penalty=PRESENCE_PENALTY)
+        self.disabled = self._create_subject(presence_penalty=0.0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py
new file mode 100644
index 000000000..e3751c14a
--- /dev/null
+++ b/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py
@@ -0,0 +1,87 @@
+import typing
+import unittest
+
+import torch
+
+from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+    BatchedRepetitionPenalizer,
+)
+from sglang.test.srt.sampling.penaltylib.utils import (
+    BaseBatchedPenalizerTest,
+    MockSamplingParams,
+    Step,
+    StepType,
+    Subject,
+)
+
+REPETITION_PENALTY = 2.0
+
+
+class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
+    Penalizer = BatchedRepetitionPenalizer
+
+    def _create_subject(self, repetition_penalty: float) -> Subject:
+        inv = 1.0 / repetition_penalty
+        return Subject(
+            sampling_params=MockSamplingParams(
+                repetition_penalty=repetition_penalty,
+            ),
+            steps=[
+                Step(
+                    type=StepType.INPUT,
+                    token_ids=[0, 1, 2],
+                    expected_tensors={
+                        "repetition_penalties": self.tensor(
+                            [[repetition_penalty] * self.vocab_size],
+                            dtype=torch.float32,
+                        ),
+                        "cumulated_repetition_penalties": (
+                            self.tensor(
+                                [[2.0, 2.0, 2.0, 1.0, 1.0]], dtype=torch.float32
+                            )
+                            if repetition_penalty != 1.0
+                            else self.tensor(
+                                [[1.0] * self.vocab_size], dtype=torch.float32
+                            )
+                        ),
+                    },
+                    expected_logits=(
+                        self.tensor([[inv, inv, inv, 1.0, 1.0]], dtype=torch.float32)
+                        if repetition_penalty != 1.0
+                        else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
+                    ),
+                ),
+                Step(
+                    type=StepType.OUTPUT,
+                    token_ids=[0, 1, 3],
+                    expected_tensors={
+                        "repetition_penalties": self.tensor(
+                            [[repetition_penalty] * self.vocab_size],
+                            dtype=torch.float32,
+                        ),
+                        "cumulated_repetition_penalties": (
+                            self.tensor(
+                                [[2.0, 2.0, 2.0, 2.0, 1.0]], dtype=torch.float32
+                            )
+                            if repetition_penalty != 1.0
+                            else self.tensor(
+                                [[1.0] * self.vocab_size], dtype=torch.float32
+                            )
+                        ),
+                    },
+                    expected_logits=(
+                        self.tensor([[inv, inv, inv, inv, 1.0]], dtype=torch.float32)
+                        if repetition_penalty != 1.0
+                        else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
+                    ),
+                ),
+            ],
+        )
+
+    def create_test_subjects(self) -> typing.List[Subject]:
+        self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
+        self.disabled = self._create_subject(repetition_penalty=1.0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py
new file mode 100644
index 000000000..5ea6af7cc
--- /dev/null
+++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py
@@ -0,0 +1,75 @@
+import json
+import unittest
+
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestBatchPenalizerE2E(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = "http://127.0.0.1:8157"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=300,
+            other_args=(
+                "--random-seed",
+                "0",
+            ),
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def run_decode(
+        self,
+        return_logprob=True,
+        top_logprobs_num=5,
+        return_text=True,
+        n=1,
+        **sampling_params,
+    ):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                # Prompt that is expected to complete in fewer than 32 tokens.
+                "text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+                "sampling_params": {
+                    "max_new_tokens": 32,
+                    "n": n,
+                    **sampling_params,
+                },
+                "stream": False,
+                "return_logprob": return_logprob,
+                "top_logprobs_num": top_logprobs_num,
+                "return_text_in_logprobs": return_text,
+                "logprob_start_len": 0,
+            },
+        )
+        print(json.dumps(response.json()))
+        print("=" * 100)
+
+    def test_default_values(self):
+        self.run_decode()
+
+    def test_frequency_penalty(self):
+        self.run_decode(frequency_penalty=2)
+
+    def test_min_new_tokens(self):
+        self.run_decode(min_new_tokens=16)
+
+    def test_presence_penalty(self):
+        self.run_decode(presence_penalty=2)
+
+    def test_repetition_penalty(self):
+        self.run_decode(repetition_penalty=2)
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")
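
The unit tests above all drive the same orchestrator lifecycle: the token ids produced at each decode step are cumulated into the orchestrator, the orchestrator rewrites the logits before sampling, and filter/merge keep its per-request state aligned as requests finish or batches are fused. Below is a minimal sketch of that per-step loop, using only the methods exercised by the tests; it assumes an `orchestrator` has already been constructed and prepared for the batch (construction lives in python/sglang/srt/sampling/penaltylib/orchestrator.py and is not reproduced here).

    def penalized_decode_step(orchestrator, logits, sampled_token_ids):
        # sampled_token_ids: one list of token ids per running request,
        # i.e. the tokens emitted by the previous decode step. Cumulating
        # them keeps the frequency/presence/repetition counts in sync
        # with the generated text.
        orchestrator.cumulate_output_tokens(sampled_token_ids)
        # Every enabled penalizer adjusts the logits in a single pass;
        # the adjusted logits are what the sampler should consume.
        return orchestrator.apply(logits)

    # When requests finish, drop their rows from the penalizer state:
    #     orchestrator.filter(indices_to_keep=[0, 2])
    # When two running batches are combined, combine their state as well:
    #     ours.merge(theirs)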
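
End to end, the new parameters ride along in `sampling_params` exactly as the smoke test above sends them. The following is a hedged client-side sketch rather than part of the patch: it assumes a server already launched locally on port 8157 (the port used by the E2E test), and the penalty values are arbitrary illustrations.

    import requests

    response = requests.post(
        "http://127.0.0.1:8157/generate",
        json={
            "text": "What is the answer for 1 + 1 = ?",
            "sampling_params": {
                "max_new_tokens": 32,
                # Illustrative penalty values; any subset may be supplied.
                "frequency_penalty": 0.5,
                "presence_penalty": 0.5,
                "repetition_penalty": 1.1,
                "min_new_tokens": 8,
            },
        },
    )
    print(response.json())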