Slightly improve the sampler to skip unnecessary steps (#6956)

This commit is contained in:
Lianmin Zheng
2025-06-08 03:18:54 -07:00
committed by GitHub
parent 6c0a48282a
commit 608668e143
7 changed files with 109 additions and 93 deletions

View File

@@ -9,10 +9,12 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import TOP_K_ALL
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
logger = logging.getLogger(__name__)
@@ -27,6 +29,12 @@ class SamplingBatchInfo:
# Whether every request uses greedy sampling (top_k <= 1)
is_all_greedy: bool
# Whether any request needs top_p sampling (top_p != 1.0)
need_top_p_sampling: bool
# Whether any request needs top_k sampling (top_k != TOP_K_ALL)
need_top_k_sampling: bool
# Whether any request needs min_p sampling (min_p > 0)
need_min_p_sampling: bool
@@ -133,6 +141,8 @@ class SamplingBatchInfo:
top_ks=top_ks,
min_ps=min_ps,
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
penalizer_orchestrator=penalizer_orchestrator,
@@ -167,7 +177,7 @@ class SamplingBatchInfo:
# Apply the mask
for i, grammar in enumerate(self.grammars):
if grammar and not grammar.finished:
if grammar and not grammar.finished and not grammar.is_terminated():
grammar.fill_vocab_mask(self.vocab_mask, i)
# Move the mask to the device if needed
@@ -308,4 +318,6 @@ class SamplingBatchInfo:
setattr(self, item, torch.cat([self_val, other_val]))
self.is_all_greedy &= other.is_all_greedy
self.need_top_p_sampling |= other.need_top_p_sampling
self.need_top_k_sampling |= other.need_top_k_sampling
self.need_min_p_sampling |= other.need_min_p_sampling

View File

@@ -16,6 +16,7 @@
from typing import Any, Dict, List, Optional, Union
_SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30
class SamplingParams:
@@ -84,7 +85,7 @@ class SamplingParams:
self.temperature = 1.0
self.top_k = 1
if self.top_k == -1:
self.top_k = 1 << 30 # whole vocabulary
self.top_k = TOP_K_ALL # whole vocabulary
def verify(self):
if self.temperature < 0.0: