feat: add thinking_budget (#6089)

This commit is contained in:
thyecust
2025-05-09 23:22:09 +08:00
committed by GitHub
parent dff0ab92eb
commit 63484f9fd6
9 changed files with 196 additions and 5 deletions

View File

@@ -30,8 +30,13 @@ class SamplingBatchInfo:
# Whether any request needs min_p sampling
need_min_p_sampling: bool
# Use thinking_budget to truncate thinking
num_thinking_tokens: Optional[torch.Tensor] = None
think_end_ids: Optional[torch.Tensor] = None
thinking_budgets: Optional[torch.Tensor] = None
# Masking tensors for grammar-guided structured outputs
vocab_size: int
vocab_size: int = 0
grammars: Optional[List] = None
vocab_mask: Optional[torch.Tensor] = None
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -76,7 +81,22 @@ class SamplingBatchInfo:
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)
if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
think_end_ids = torch.tensor(
[getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
dtype=torch.int64,
).to(device, non_blocking=True)
num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
device, non_blocking=True
)
thinking_budgets = torch.tensor(
[r.sampling_params.thinking_budget or -1 for r in reqs],
dtype=torch.int64,
).to(device, non_blocking=True)
else:
think_end_ids = None
num_thinking_tokens = None
thinking_budgets = None
# Check if any request has custom logit processor
has_custom_logit_processor = (
batch.enable_custom_logit_processor # check the flag first.
@@ -132,6 +152,9 @@ class SamplingBatchInfo:
top_ps=top_ps,
top_ks=top_ks,
min_ps=min_ps,
think_end_ids=think_end_ids,
num_thinking_tokens=num_thinking_tokens,
thinking_budgets=thinking_budgets,
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
@@ -146,6 +169,35 @@ class SamplingBatchInfo:
def __len__(self):
return len(self.temperatures)
def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
has_budget = self.thinking_budgets > 0
if not has_budget.any():
return
torch.where(
has_budget,
self.num_thinking_tokens + 1,
self.num_thinking_tokens,
out=self.num_thinking_tokens,
)
should_stop = has_budget & (
self.num_thinking_tokens - 1 > self.thinking_budgets
)
next_token_logits.masked_fill_(should_stop.unsqueeze(0), float("-inf"))
batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
if len(batch_indices) > 0:
end_token_indices = self.think_end_ids[batch_indices]
next_token_logits[batch_indices, end_token_indices] = 0.0
def update_thinking_budgets(self, next_token_ids: torch.Tensor):
if not torch.any(self.thinking_budgets > 0):
return
torch.where(
next_token_ids == self.think_end_ids,
torch.tensor(-1, device=self.thinking_budgets.device),
self.thinking_budgets,
out=self.thinking_budgets,
)
def update_regex_vocab_mask(self):
if not self.grammars:
self.vocab_mask = None

View File

@@ -30,6 +30,7 @@ class SamplingParams:
def __init__(
self,
max_new_tokens: int = 128,
thinking_budget: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
@@ -57,6 +58,7 @@ class SamplingParams:
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.thinking_budget = thinking_budget
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k