Support min-p sampling (#1167)

This commit is contained in:
intervitens
2024-08-22 01:49:32 +03:00
committed by GitHub
parent d6aeb9fa15
commit 068e9eae55
7 changed files with 58 additions and 9 deletions

View File

@@ -130,6 +130,7 @@ class CompiledFunction:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
@@ -145,6 +146,7 @@ class CompiledFunction:
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
@@ -160,6 +162,7 @@ class CompiledFunction:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
@@ -178,6 +181,7 @@ class CompiledFunction:
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)

View File

@@ -663,6 +663,7 @@ class StreamExecutor:
"temperature",
"top_p",
"top_k",
"min_p",
"frequency_penalty",
"presence_penalty",
"ignore_eos",

View File

@@ -22,6 +22,7 @@ class SglSamplingParams:
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
min_p: float = 0.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
ignore_eos: bool = False
@@ -42,6 +43,7 @@ class SglSamplingParams:
self.temperature,
self.top_p,
self.top_k,
self.min_p,
self.frequency_penalty,
self.presence_penalty,
self.ignore_eos,
@@ -114,6 +116,7 @@ class SglSamplingParams:
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"min_p": self.min_p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"ignore_eos": self.ignore_eos,
@@ -149,6 +152,7 @@ class SglFunction:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
@@ -169,6 +173,7 @@ class SglFunction:
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
@@ -190,6 +195,7 @@ class SglFunction:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
@@ -228,6 +234,7 @@ class SglFunction:
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
@@ -408,6 +415,7 @@ class SglGen(SglExpr):
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
min_p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
@@ -428,6 +436,7 @@ class SglGen(SglExpr):
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,