feat: add thinking_budget (#6089)
This commit is contained in:
@@ -30,8 +30,13 @@ class SamplingBatchInfo:
|
||||
# Whether any request needs min_p sampling
|
||||
need_min_p_sampling: bool
|
||||
|
||||
# Use thinking_budget to truncate thinking
|
||||
num_thinking_tokens: Optional[torch.Tensor] = None
|
||||
think_end_ids: Optional[torch.Tensor] = None
|
||||
thinking_budgets: Optional[torch.Tensor] = None
|
||||
|
||||
# Masking tensors for grammar-guided structured outputs
|
||||
vocab_size: int
|
||||
vocab_size: int = 0
|
||||
grammars: Optional[List] = None
|
||||
vocab_mask: Optional[torch.Tensor] = None
|
||||
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||
@@ -76,7 +81,22 @@ class SamplingBatchInfo:
|
||||
min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
|
||||
think_end_ids = torch.tensor(
|
||||
[getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
|
||||
dtype=torch.int64,
|
||||
).to(device, non_blocking=True)
|
||||
num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
|
||||
device, non_blocking=True
|
||||
)
|
||||
thinking_budgets = torch.tensor(
|
||||
[r.sampling_params.thinking_budget or -1 for r in reqs],
|
||||
dtype=torch.int64,
|
||||
).to(device, non_blocking=True)
|
||||
else:
|
||||
think_end_ids = None
|
||||
num_thinking_tokens = None
|
||||
thinking_budgets = None
|
||||
# Check if any request has custom logit processor
|
||||
has_custom_logit_processor = (
|
||||
batch.enable_custom_logit_processor # check the flag first.
|
||||
@@ -132,6 +152,9 @@ class SamplingBatchInfo:
|
||||
top_ps=top_ps,
|
||||
top_ks=top_ks,
|
||||
min_ps=min_ps,
|
||||
think_end_ids=think_end_ids,
|
||||
num_thinking_tokens=num_thinking_tokens,
|
||||
thinking_budgets=thinking_budgets,
|
||||
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
vocab_size=vocab_size,
|
||||
@@ -146,6 +169,35 @@ class SamplingBatchInfo:
|
||||
def __len__(self):
|
||||
return len(self.temperatures)
|
||||
|
||||
def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
|
||||
has_budget = self.thinking_budgets > 0
|
||||
if not has_budget.any():
|
||||
return
|
||||
torch.where(
|
||||
has_budget,
|
||||
self.num_thinking_tokens + 1,
|
||||
self.num_thinking_tokens,
|
||||
out=self.num_thinking_tokens,
|
||||
)
|
||||
should_stop = has_budget & (
|
||||
self.num_thinking_tokens - 1 > self.thinking_budgets
|
||||
)
|
||||
next_token_logits.masked_fill_(should_stop.unsqueeze(0), float("-inf"))
|
||||
batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
|
||||
if len(batch_indices) > 0:
|
||||
end_token_indices = self.think_end_ids[batch_indices]
|
||||
next_token_logits[batch_indices, end_token_indices] = 0.0
|
||||
|
||||
def update_thinking_budgets(self, next_token_ids: torch.Tensor):
|
||||
if not torch.any(self.thinking_budgets > 0):
|
||||
return
|
||||
torch.where(
|
||||
next_token_ids == self.think_end_ids,
|
||||
torch.tensor(-1, device=self.thinking_budgets.device),
|
||||
self.thinking_budgets,
|
||||
out=self.thinking_budgets,
|
||||
)
|
||||
|
||||
def update_regex_vocab_mask(self):
|
||||
if not self.grammars:
|
||||
self.vocab_mask = None
|
||||
|
||||
@@ -30,6 +30,7 @@ class SamplingParams:
|
||||
def __init__(
|
||||
self,
|
||||
max_new_tokens: int = 128,
|
||||
thinking_budget: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
temperature: float = 1.0,
|
||||
@@ -57,6 +58,7 @@ class SamplingParams:
|
||||
self.stop_token_ids = set(stop_token_ids)
|
||||
else:
|
||||
self.stop_token_ids = None
|
||||
self.thinking_budget = thinking_budget
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
|
||||
Reference in New Issue
Block a user