Support min-p sampling (#1167)
This commit is contained in:
@@ -45,6 +45,8 @@ temperature: float = 1.0,
|
|||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
# Top-k sampling
|
# Top-k sampling
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
|
# Min-p sampling
|
||||||
|
min_p: float = 0.0,
|
||||||
# Whether to ignore EOS token.
|
# Whether to ignore EOS token.
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
# Whether to skip the special tokens during detokenization.
|
# Whether to skip the special tokens during detokenization.
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ def gen(
|
|||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
|
min_p: Optional[float] = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
ignore_eos: Optional[bool] = None,
|
ignore_eos: Optional[bool] = None,
|
||||||
@@ -103,6 +104,7 @@ def gen(
|
|||||||
temperature,
|
temperature,
|
||||||
top_p,
|
top_p,
|
||||||
top_k,
|
top_k,
|
||||||
|
min_p,
|
||||||
frequency_penalty,
|
frequency_penalty,
|
||||||
presence_penalty,
|
presence_penalty,
|
||||||
ignore_eos,
|
ignore_eos,
|
||||||
@@ -123,6 +125,7 @@ def gen_int(
|
|||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
|
min_p: Optional[float] = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
ignore_eos: Optional[bool] = None,
|
ignore_eos: Optional[bool] = None,
|
||||||
@@ -139,6 +142,7 @@ def gen_int(
|
|||||||
temperature,
|
temperature,
|
||||||
top_p,
|
top_p,
|
||||||
top_k,
|
top_k,
|
||||||
|
min_p,
|
||||||
frequency_penalty,
|
frequency_penalty,
|
||||||
presence_penalty,
|
presence_penalty,
|
||||||
ignore_eos,
|
ignore_eos,
|
||||||
@@ -159,6 +163,7 @@ def gen_string(
|
|||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
|
min_p: Optional[float] = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
ignore_eos: Optional[bool] = None,
|
ignore_eos: Optional[bool] = None,
|
||||||
@@ -175,6 +180,7 @@ def gen_string(
|
|||||||
temperature,
|
temperature,
|
||||||
top_p,
|
top_p,
|
||||||
top_k,
|
top_k,
|
||||||
|
min_p,
|
||||||
frequency_penalty,
|
frequency_penalty,
|
||||||
presence_penalty,
|
presence_penalty,
|
||||||
ignore_eos,
|
ignore_eos,
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ class CompiledFunction:
|
|||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
|
min_p: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
backend=None,
|
backend=None,
|
||||||
@@ -145,6 +146,7 @@ class CompiledFunction:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
min_p=min_p,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
)
|
)
|
||||||
@@ -160,6 +162,7 @@ class CompiledFunction:
|
|||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
|
min_p: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
backend=None,
|
backend=None,
|
||||||
@@ -178,6 +181,7 @@ class CompiledFunction:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
min_p=min_p,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -663,6 +663,7 @@ class StreamExecutor:
|
|||||||
"temperature",
|
"temperature",
|
||||||
"top_p",
|
"top_p",
|
||||||
"top_k",
|
"top_k",
|
||||||
|
"min_p",
|
||||||
"frequency_penalty",
|
"frequency_penalty",
|
||||||
"presence_penalty",
|
"presence_penalty",
|
||||||
"ignore_eos",
|
"ignore_eos",
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ class SglSamplingParams:
|
|||||||
temperature: float = 1.0
|
temperature: float = 1.0
|
||||||
top_p: float = 1.0
|
top_p: float = 1.0
|
||||||
top_k: int = -1 # -1 means disable
|
top_k: int = -1 # -1 means disable
|
||||||
|
min_p: float = 0.0
|
||||||
frequency_penalty: float = 0.0
|
frequency_penalty: float = 0.0
|
||||||
presence_penalty: float = 0.0
|
presence_penalty: float = 0.0
|
||||||
ignore_eos: bool = False
|
ignore_eos: bool = False
|
||||||
@@ -42,6 +43,7 @@ class SglSamplingParams:
|
|||||||
self.temperature,
|
self.temperature,
|
||||||
self.top_p,
|
self.top_p,
|
||||||
self.top_k,
|
self.top_k,
|
||||||
|
self.min_p,
|
||||||
self.frequency_penalty,
|
self.frequency_penalty,
|
||||||
self.presence_penalty,
|
self.presence_penalty,
|
||||||
self.ignore_eos,
|
self.ignore_eos,
|
||||||
@@ -114,6 +116,7 @@ class SglSamplingParams:
|
|||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
|
"min_p": self.min_p,
|
||||||
"frequency_penalty": self.frequency_penalty,
|
"frequency_penalty": self.frequency_penalty,
|
||||||
"presence_penalty": self.presence_penalty,
|
"presence_penalty": self.presence_penalty,
|
||||||
"ignore_eos": self.ignore_eos,
|
"ignore_eos": self.ignore_eos,
|
||||||
@@ -149,6 +152,7 @@ class SglFunction:
|
|||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
|
min_p: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
@@ -169,6 +173,7 @@ class SglFunction:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
min_p=min_p,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=ignore_eos,
|
||||||
@@ -190,6 +195,7 @@ class SglFunction:
|
|||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
|
min_p: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
@@ -228,6 +234,7 @@ class SglFunction:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
min_p=min_p,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=ignore_eos,
|
||||||
@@ -408,6 +415,7 @@ class SglGen(SglExpr):
|
|||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
|
min_p: Optional[float] = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
ignore_eos: Optional[bool] = None,
|
ignore_eos: Optional[bool] = None,
|
||||||
@@ -428,6 +436,7 @@ class SglGen(SglExpr):
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
min_p=min_p,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=ignore_eos,
|
||||||
|
|||||||
@@ -21,7 +21,12 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
from flashinfer.sampling import (
|
||||||
|
min_p_sampling_from_probs,
|
||||||
|
top_k_renorm_prob,
|
||||||
|
top_k_top_p_sampling_from_probs,
|
||||||
|
top_p_renorm_prob,
|
||||||
|
)
|
||||||
from vllm.distributed import get_tensor_model_parallel_group
|
from vllm.distributed import get_tensor_model_parallel_group
|
||||||
|
|
||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
@@ -339,6 +344,7 @@ class ScheduleBatch:
|
|||||||
temperatures: torch.Tensor = None
|
temperatures: torch.Tensor = None
|
||||||
top_ps: torch.Tensor = None
|
top_ps: torch.Tensor = None
|
||||||
top_ks: torch.Tensor = None
|
top_ks: torch.Tensor = None
|
||||||
|
min_ps: torch.Tensor = None
|
||||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||||
logit_bias: torch.Tensor = None
|
logit_bias: torch.Tensor = None
|
||||||
|
|
||||||
@@ -403,6 +409,9 @@ class ScheduleBatch:
|
|||||||
self.top_ks = torch.tensor(
|
self.top_ks = torch.tensor(
|
||||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
||||||
)
|
)
|
||||||
|
self.min_ps = torch.tensor(
|
||||||
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
|
||||||
|
)
|
||||||
|
|
||||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||||
@@ -701,6 +710,7 @@ class ScheduleBatch:
|
|||||||
"temperatures",
|
"temperatures",
|
||||||
"top_ps",
|
"top_ps",
|
||||||
"top_ks",
|
"top_ks",
|
||||||
|
"min_ps",
|
||||||
"logit_bias",
|
"logit_bias",
|
||||||
]:
|
]:
|
||||||
self_val = getattr(self, item, None)
|
self_val = getattr(self, item, None)
|
||||||
@@ -730,6 +740,7 @@ class ScheduleBatch:
|
|||||||
"temperatures",
|
"temperatures",
|
||||||
"top_ps",
|
"top_ps",
|
||||||
"top_ks",
|
"top_ks",
|
||||||
|
"min_ps",
|
||||||
]:
|
]:
|
||||||
self_val = getattr(self, item, None)
|
self_val = getattr(self, item, None)
|
||||||
other_val = getattr(other, item, None)
|
other_val = getattr(other, item, None)
|
||||||
@@ -780,13 +791,20 @@ class ScheduleBatch:
|
|||||||
uniform_samples = torch.rand(
|
uniform_samples = torch.rand(
|
||||||
(max_top_k_round, batch_size), device=probs.device
|
(max_top_k_round, batch_size), device=probs.device
|
||||||
)
|
)
|
||||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
if self.min_ps.any():
|
||||||
probs, uniform_samples, self.top_ks, self.top_ps
|
probs = top_k_renorm_prob(probs, self.top_ks)
|
||||||
)
|
probs = top_p_renorm_prob(probs, self.top_ps)
|
||||||
|
batch_next_token_ids, success = min_p_sampling_from_probs(
|
||||||
|
probs, uniform_samples, self.min_ps
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||||
|
probs, uniform_samples, self.top_ks, self.top_ps
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Here we provide a slower fallback implementation.
|
# Here we provide a slower fallback implementation.
|
||||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
|
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||||
probs, self.top_ks, self.top_ps
|
probs, self.top_ks, self.top_ps, self.min_ps
|
||||||
)
|
)
|
||||||
|
|
||||||
if not torch.all(success):
|
if not torch.all(success):
|
||||||
@@ -810,17 +828,22 @@ class ScheduleBatch:
|
|||||||
return batch_next_token_ids
|
return batch_next_token_ids
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_sampling_from_probs_torch(
|
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||||
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
|
probs: torch.Tensor,
|
||||||
|
top_ks: torch.Tensor,
|
||||||
|
top_ps: torch.Tensor,
|
||||||
|
min_ps: torch.Tensor,
|
||||||
):
|
):
|
||||||
"""A top-k and top-k sampling implementation with native pytorch operations."""
|
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
|
||||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
min_p_thresholds = probs_sort[:, 0] * min_ps
|
||||||
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
||||||
probs_sort[
|
probs_sort[
|
||||||
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
|
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
|
||||||
>= top_ks.view(-1, 1)
|
>= top_ks.view(-1, 1)
|
||||||
] = 0.0
|
] = 0.0
|
||||||
|
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
||||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
||||||
try:
|
try:
|
||||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class SamplingParams:
|
|||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
|
min_p: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.0,
|
||||||
@@ -42,6 +43,7 @@ class SamplingParams:
|
|||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
self.min_p = min_p
|
||||||
self.frequency_penalty = frequency_penalty
|
self.frequency_penalty = frequency_penalty
|
||||||
self.presence_penalty = presence_penalty
|
self.presence_penalty = presence_penalty
|
||||||
self.repetition_penalty = repetition_penalty
|
self.repetition_penalty = repetition_penalty
|
||||||
@@ -69,6 +71,8 @@ class SamplingParams:
|
|||||||
)
|
)
|
||||||
if not 0.0 < self.top_p <= 1.0:
|
if not 0.0 < self.top_p <= 1.0:
|
||||||
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
|
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
|
||||||
|
if not 0.0 <= self.min_p <= 1.0:
|
||||||
|
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
|
||||||
if self.top_k < -1 or self.top_k == 0:
|
if self.top_k < -1 or self.top_k == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
|
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
|
||||||
|
|||||||
Reference in New Issue
Block a user