From 068e9eae55daf2ca1666cfa64ad66139b02fa623 Mon Sep 17 00:00:00 2001
From: intervitens <155717317+intervitens@users.noreply.github.com>
Date: Thu, 22 Aug 2024 01:49:32 +0300
Subject: [PATCH] Support min-p sampling (#1167)

---
 docs/en/sampling_params.md                   |  2 +
 python/sglang/api.py                         |  6 +++
 python/sglang/lang/compiler.py               |  4 ++
 python/sglang/lang/interpreter.py            |  1 +
 python/sglang/lang/ir.py                     |  9 +++++
 python/sglang/srt/managers/schedule_batch.py | 41 +++++++++++++++-----
 python/sglang/srt/sampling_params.py         |  4 ++
 7 files changed, 58 insertions(+), 9 deletions(-)

diff --git a/docs/en/sampling_params.md b/docs/en/sampling_params.md
index 5f1cdece6..7d866e692 100644
--- a/docs/en/sampling_params.md
+++ b/docs/en/sampling_params.md
@@ -45,6 +45,8 @@ temperature: float = 1.0,
 top_p: float = 1.0,
 # Top-k sampling
 top_k: int = -1,
+# Min-p sampling
+min_p: float = 0.0,
 # Whether to ignore EOS token.
 ignore_eos: bool = False,
 # Whether to skip the special tokens during detokenization.
diff --git a/python/sglang/api.py b/python/sglang/api.py
index 887ffce76..3a2f747be 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -66,6 +66,7 @@ def gen(
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
+    min_p: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
@@ -103,6 +104,7 @@
         temperature,
         top_p,
         top_k,
+        min_p,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
@@ -123,6 +125,7 @@ def gen_int(
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
+    min_p: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
@@ -139,6 +142,7 @@
         temperature,
         top_p,
         top_k,
+        min_p,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
@@ -159,6 +163,7 @@ def gen_string(
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
+    min_p: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
@@ -175,6 +180,7 @@
         temperature,
         top_p,
         top_k,
+        min_p,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
diff --git a/python/sglang/lang/compiler.py b/python/sglang/lang/compiler.py
index 95af04adb..5e1b411fc 100644
--- a/python/sglang/lang/compiler.py
+++ b/python/sglang/lang/compiler.py
@@ -130,6 +130,7 @@ class CompiledFunction:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         backend=None,
@@ -145,6 +146,7 @@ class CompiledFunction:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
         )
@@ -160,6 +162,7 @@ class CompiledFunction:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         backend=None,
@@ -178,6 +181,7 @@ class CompiledFunction:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
         )
diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index 844c9d062..306d280c7 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -663,6 +663,7 @@ class StreamExecutor:
             "temperature",
"top_p", "top_k", + "min_p", "frequency_penalty", "presence_penalty", "ignore_eos", diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 9db5f2719..199a7ac7a 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -22,6 +22,7 @@ class SglSamplingParams: temperature: float = 1.0 top_p: float = 1.0 top_k: int = -1 # -1 means disable + min_p: float = 0.0 frequency_penalty: float = 0.0 presence_penalty: float = 0.0 ignore_eos: bool = False @@ -42,6 +43,7 @@ class SglSamplingParams: self.temperature, self.top_p, self.top_k, + self.min_p, self.frequency_penalty, self.presence_penalty, self.ignore_eos, @@ -114,6 +116,7 @@ class SglSamplingParams: "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, + "min_p": self.min_p, "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, "ignore_eos": self.ignore_eos, @@ -149,6 +152,7 @@ class SglFunction: temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, + min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, ignore_eos: bool = False, @@ -169,6 +173,7 @@ class SglFunction: temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ignore_eos=ignore_eos, @@ -190,6 +195,7 @@ class SglFunction: temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, + min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, ignore_eos: bool = False, @@ -228,6 +234,7 @@ class SglFunction: temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ignore_eos=ignore_eos, @@ -408,6 +415,7 @@ class SglGen(SglExpr): temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, + min_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, ignore_eos: Optional[bool] = None, @@ -428,6 +436,7 @@ class SglGen(SglExpr): temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ignore_eos=ignore_eos, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 1437d0e6c..9abce6f9b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -21,7 +21,12 @@ from typing import List, Optional, Union import torch import torch.distributed as dist -from flashinfer.sampling import top_k_top_p_sampling_from_probs +from flashinfer.sampling import ( + min_p_sampling_from_probs, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, +) from vllm.distributed import get_tensor_model_parallel_group import sglang.srt.sampling.penaltylib as penaltylib @@ -339,6 +344,7 @@ class ScheduleBatch: temperatures: torch.Tensor = None top_ps: torch.Tensor = None top_ks: torch.Tensor = None + min_ps: torch.Tensor = None penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None logit_bias: torch.Tensor = None @@ -403,6 +409,9 @@ class ScheduleBatch: self.top_ks = torch.tensor( [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device ) + self.min_ps = torch.tensor( + [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device + ) # Each penalizers will do nothing if they evaluate themselves as not required by looking at # the sampling_params of the requests (See 
@@ -701,6 +710,7 @@
             "temperatures",
             "top_ps",
             "top_ks",
+            "min_ps",
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
@@ -730,6 +740,7 @@
             "temperatures",
             "top_ps",
             "top_ks",
+            "min_ps",
         ]:
             self_val = getattr(self, item, None)
             other_val = getattr(other, item, None)
@@ -780,13 +791,20 @@
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
             )
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, self.top_ks, self.top_ps
-            )
+            if self.min_ps.any():
+                probs = top_k_renorm_prob(probs, self.top_ks)
+                probs = top_p_renorm_prob(probs, self.top_ps)
+                batch_next_token_ids, success = min_p_sampling_from_probs(
+                    probs, uniform_samples, self.min_ps
+                )
+            else:
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    probs, uniform_samples, self.top_ks, self.top_ps
+                )
         else:
             # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
-                probs, self.top_ks, self.top_ps
+            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+                probs, self.top_ks, self.top_ps, self.min_ps
             )
 
         if not torch.all(success):
@@ -810,17 +828,22 @@
         return batch_next_token_ids
 
 
-def top_k_top_p_sampling_from_probs_torch(
-    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+def top_k_top_p_min_p_sampling_from_probs_torch(
+    probs: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+    min_ps: torch.Tensor,
 ):
-    """A top-k and top-k sampling implementation with native pytorch operations."""
+    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
+    min_p_thresholds = probs_sort[:, 0] * min_ps
     probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
+    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
         sampled_index = torch.multinomial(probs_sort, num_samples=1)
diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py
index 712827d79..c30717dd7 100644
--- a/python/sglang/srt/sampling_params.py
+++ b/python/sglang/srt/sampling_params.py
@@ -30,6 +30,7 @@ class SamplingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
@@ -42,6 +43,7 @@
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.min_p = min_p
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
@@ -69,6 +71,8 @@
             )
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(
                 f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
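
For reference: min-p keeps a token only if its probability is at least min_p times
the probability of the single most likely token, so the cutoff adapts to how peaked
the distribution is. Below is a minimal standalone sketch of that rule in plain
PyTorch; the helper name min_p_filter is hypothetical and independent of the patch
code above.

    import torch

    def min_p_filter(probs: torch.Tensor, min_p: float) -> torch.Tensor:
        # probs: (batch, vocab) probabilities that sum to 1 along the last dim.
        max_probs = probs.max(dim=-1, keepdim=True).values
        # Drop every token whose probability is below min_p * (top token's probability),
        filtered = probs.masked_fill(probs < min_p * max_probs, 0.0)
        # then renormalize so the surviving tokens again sum to 1.
        return filtered / filtered.sum(dim=-1, keepdim=True)

    # With min_p=0.1 the cutoff is 0.1 * 0.5 = 0.05, so the last two tokens are removed.
    probs = torch.tensor([[0.50, 0.30, 0.15, 0.04, 0.01]])
    print(min_p_filter(probs, min_p=0.1))  # tensor([[0.5263, 0.3158, 0.1579, 0.0000, 0.0000]])

With the patch applied, min_p is plumbed from the frontend (sgl.gen(..., min_p=0.1))
down to SamplingParams. The default of 0.0 disables the filter, which is what the
self.min_ps.any() fast path in schedule_batch.py checks before choosing between
min_p_sampling_from_probs and top_k_top_p_sampling_from_probs.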