[Feature] Add Logit Bias (#6579)
Co-authored-by: Cinjon Resnick <cinjon.resnick@gmail.com>
This commit is contained in:
@@ -582,6 +582,7 @@ def v1_generate_request(
|
||||
"no_stop_trim": request.no_stop_trim,
|
||||
"ignore_eos": request.ignore_eos,
|
||||
"skip_special_tokens": request.skip_special_tokens,
|
||||
"logit_bias": request.logit_bias,
|
||||
}
|
||||
)
|
||||
return_logprobs.append(request.logprobs is not None)
|
||||
@@ -1219,6 +1220,7 @@ def v1_chat_generate_request(
|
||||
"no_stop_trim": request.no_stop_trim,
|
||||
"ignore_eos": request.ignore_eos,
|
||||
"skip_special_tokens": request.skip_special_tokens,
|
||||
"logit_bias": request.logit_bias,
|
||||
}
|
||||
|
||||
if request.response_format and request.response_format.type == "json_schema":
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.sampling_params import TOP_K_ALL
|
||||
from sglang.srt.utils import merge_bias_tensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
@@ -63,6 +64,9 @@ class SamplingBatchInfo:
|
||||
# Device
|
||||
device: str = "cuda"
|
||||
|
||||
# Handle logit bias
|
||||
logit_bias: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
reqs = batch.reqs
|
||||
@@ -85,6 +89,14 @@ class SamplingBatchInfo:
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
logit_bias = None
|
||||
if any(r.sampling_params.logit_bias is not None for r in reqs):
|
||||
logit_bias = torch.zeros(len(reqs), vocab_size, device=device)
|
||||
for i, r in enumerate(reqs):
|
||||
if r.sampling_params.logit_bias is not None:
|
||||
for key, value in r.sampling_params.logit_bias.items():
|
||||
logit_bias[i, int(key)] = value
|
||||
|
||||
# Check if any request has custom logit processor
|
||||
has_custom_logit_processor = (
|
||||
batch.enable_custom_logit_processor # check the flag first.
|
||||
@@ -150,6 +162,7 @@ class SamplingBatchInfo:
|
||||
custom_params=custom_params,
|
||||
custom_logit_processor=merged_custom_logit_processor,
|
||||
device=device,
|
||||
logit_bias=logit_bias,
|
||||
)
|
||||
return ret
|
||||
|
||||
@@ -206,6 +219,9 @@ class SamplingBatchInfo:
|
||||
if self.vocab_mask is not None:
|
||||
self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
|
||||
|
||||
if self.logit_bias is not None:
|
||||
logits.add_(self.logit_bias)
|
||||
|
||||
def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
|
||||
self.penalizer_orchestrator.filter(keep_indices_device)
|
||||
|
||||
@@ -221,6 +237,9 @@ class SamplingBatchInfo:
|
||||
value = getattr(self, item, None)
|
||||
setattr(self, item, value[keep_indices_device])
|
||||
|
||||
if self.logit_bias is not None:
|
||||
self.logit_bias = self.logit_bias[keep_indices_device]
|
||||
|
||||
def _filter_batch_custom_logit_processor(
|
||||
self, keep_indices: List[int], keep_indices_device: torch.Tensor
|
||||
):
|
||||
@@ -321,3 +340,8 @@ class SamplingBatchInfo:
|
||||
self.need_top_p_sampling |= other.need_top_p_sampling
|
||||
self.need_top_k_sampling |= other.need_top_k_sampling
|
||||
self.need_min_p_sampling |= other.need_min_p_sampling
|
||||
|
||||
# Merge logit bias
|
||||
self.logit_bias = merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
|
||||
)
|
||||
|
||||
@@ -52,6 +52,7 @@ class SamplingParams:
|
||||
no_stop_trim: bool = False,
|
||||
custom_params: Optional[Dict[str, Any]] = None,
|
||||
stream_interval: Optional[int] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
) -> None:
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.stop_strs = stop
|
||||
@@ -78,6 +79,7 @@ class SamplingParams:
|
||||
self.no_stop_trim = no_stop_trim
|
||||
self.custom_params = custom_params
|
||||
self.stream_interval = stream_interval
|
||||
self.logit_bias = logit_bias
|
||||
|
||||
# Process some special cases
|
||||
if 0 <= self.temperature < _SAMPLING_EPS:
|
||||
|
||||
@@ -2210,6 +2210,45 @@ class Withable(Generic[T]):
|
||||
self._value = None
|
||||
|
||||
|
||||
def merge_bias_tensor(
|
||||
lhs: Optional[torch.Tensor],
|
||||
rhs: Optional[torch.Tensor],
|
||||
bs1: int,
|
||||
bs2: int,
|
||||
device: str,
|
||||
default: float,
|
||||
):
|
||||
"""Merge two bias tensors for batch merging.
|
||||
|
||||
Args:
|
||||
lhs: Left-hand side tensor
|
||||
rhs: Right-hand side tensor
|
||||
bs1: Batch size of left-hand side tensor
|
||||
bs2: Batch size of right-hand side tensor
|
||||
device: Device to place the merged tensor on
|
||||
default: Default value for missing tensor elements
|
||||
|
||||
Returns:
|
||||
Merged tensor or None if both inputs are None
|
||||
"""
|
||||
if lhs is None and rhs is None:
|
||||
return None
|
||||
|
||||
if lhs is not None and rhs is not None:
|
||||
return torch.cat([lhs, rhs])
|
||||
else:
|
||||
if lhs is not None:
|
||||
shape, dtype = lhs.shape[1:], lhs.dtype
|
||||
else:
|
||||
shape, dtype = rhs.shape[1:], rhs.dtype
|
||||
|
||||
if lhs is None:
|
||||
lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
|
||||
if rhs is None:
|
||||
rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
|
||||
return torch.cat([lhs, rhs])
|
||||
|
||||
|
||||
def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
|
||||
import huggingface_hub as hf
|
||||
|
||||
|
||||
Reference in New Issue
Block a user