Support min-p sampling (#1167)
This commit is contained in:
@@ -66,6 +66,7 @@ def gen(
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
min_p: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
@@ -103,6 +104,7 @@ def gen(
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
min_p,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
@@ -123,6 +125,7 @@ def gen_int(
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
min_p: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
@@ -139,6 +142,7 @@ def gen_int(
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
min_p,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
@@ -159,6 +163,7 @@ def gen_string(
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
min_p: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
@@ -175,6 +180,7 @@ def gen_string(
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
min_p,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
|
||||
@@ -130,6 +130,7 @@ class CompiledFunction:
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
backend=None,
|
||||
@@ -145,6 +146,7 @@ class CompiledFunction:
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
)
|
||||
@@ -160,6 +162,7 @@ class CompiledFunction:
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
backend=None,
|
||||
@@ -178,6 +181,7 @@ class CompiledFunction:
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
)
|
||||
|
||||
@@ -663,6 +663,7 @@ class StreamExecutor:
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"min_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"ignore_eos",
|
||||
|
||||
@@ -22,6 +22,7 @@ class SglSamplingParams:
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = -1 # -1 means disable
|
||||
min_p: float = 0.0
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
ignore_eos: bool = False
|
||||
@@ -42,6 +43,7 @@ class SglSamplingParams:
|
||||
self.temperature,
|
||||
self.top_p,
|
||||
self.top_k,
|
||||
self.min_p,
|
||||
self.frequency_penalty,
|
||||
self.presence_penalty,
|
||||
self.ignore_eos,
|
||||
@@ -114,6 +116,7 @@ class SglSamplingParams:
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"min_p": self.min_p,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"ignore_eos": self.ignore_eos,
|
||||
@@ -149,6 +152,7 @@ class SglFunction:
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
ignore_eos: bool = False,
|
||||
@@ -169,6 +173,7 @@ class SglFunction:
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
@@ -190,6 +195,7 @@ class SglFunction:
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
ignore_eos: bool = False,
|
||||
@@ -228,6 +234,7 @@ class SglFunction:
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
@@ -408,6 +415,7 @@ class SglGen(SglExpr):
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
min_p: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
@@ -428,6 +436,7 @@ class SglGen(SglExpr):
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
|
||||
@@ -21,7 +21,12 @@ from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||
from flashinfer.sampling import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
)
|
||||
from vllm.distributed import get_tensor_model_parallel_group
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
@@ -339,6 +344,7 @@ class ScheduleBatch:
|
||||
temperatures: torch.Tensor = None
|
||||
top_ps: torch.Tensor = None
|
||||
top_ks: torch.Tensor = None
|
||||
min_ps: torch.Tensor = None
|
||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||
logit_bias: torch.Tensor = None
|
||||
|
||||
@@ -403,6 +409,9 @@ class ScheduleBatch:
|
||||
self.top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
||||
)
|
||||
self.min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
|
||||
)
|
||||
|
||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||
@@ -701,6 +710,7 @@ class ScheduleBatch:
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"min_ps",
|
||||
"logit_bias",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
@@ -730,6 +740,7 @@ class ScheduleBatch:
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"min_ps",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
other_val = getattr(other, item, None)
|
||||
@@ -780,13 +791,20 @@ class ScheduleBatch:
|
||||
uniform_samples = torch.rand(
|
||||
(max_top_k_round, batch_size), device=probs.device
|
||||
)
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
probs, uniform_samples, self.top_ks, self.top_ps
|
||||
)
|
||||
if self.min_ps.any():
|
||||
probs = top_k_renorm_prob(probs, self.top_ks)
|
||||
probs = top_p_renorm_prob(probs, self.top_ps)
|
||||
batch_next_token_ids, success = min_p_sampling_from_probs(
|
||||
probs, uniform_samples, self.min_ps
|
||||
)
|
||||
else:
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
probs, uniform_samples, self.top_ks, self.top_ps
|
||||
)
|
||||
else:
|
||||
# Here we provide a slower fallback implementation.
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
|
||||
probs, self.top_ks, self.top_ps
|
||||
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs, self.top_ks, self.top_ps, self.min_ps
|
||||
)
|
||||
|
||||
if not torch.all(success):
|
||||
@@ -810,17 +828,22 @@ class ScheduleBatch:
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_probs_torch(
|
||||
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
|
||||
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs: torch.Tensor,
|
||||
top_ks: torch.Tensor,
|
||||
top_ps: torch.Tensor,
|
||||
min_ps: torch.Tensor,
|
||||
):
|
||||
"""A top-k and top-k sampling implementation with native pytorch operations."""
|
||||
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
|
||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
min_p_thresholds = probs_sort[:, 0] * min_ps
|
||||
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
||||
probs_sort[
|
||||
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
|
||||
>= top_ks.view(-1, 1)
|
||||
] = 0.0
|
||||
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
||||
try:
|
||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
||||
|
||||
@@ -30,6 +30,7 @@ class SamplingParams:
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
repetition_penalty: float = 1.0,
|
||||
@@ -42,6 +43,7 @@ class SamplingParams:
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.min_p = min_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.repetition_penalty = repetition_penalty
|
||||
@@ -69,6 +71,8 @@ class SamplingParams:
|
||||
)
|
||||
if not 0.0 < self.top_p <= 1.0:
|
||||
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
|
||||
if not 0.0 <= self.min_p <= 1.0:
|
||||
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
|
||||
if self.top_k < -1 or self.top_k == 0:
|
||||
raise ValueError(
|
||||
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
|
||||
|
||||
Reference in New Issue
Block a user