Slightly improve the sampler to skip unnecessary steps (#6956)
This commit is contained in:
@@ -9,10 +9,12 @@ import torch
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.sampling_params import TOP_K_ALL
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -27,6 +29,12 @@ class SamplingBatchInfo:
|
||||
# Whether all requests use greedy sampling
|
||||
is_all_greedy: bool
|
||||
|
||||
# Whether any requests use top_p sampling
|
||||
need_top_p_sampling: bool
|
||||
|
||||
# Whether any requests use top_k sampling
|
||||
need_top_k_sampling: bool
|
||||
|
||||
# Whether any request needs min_p sampling
|
||||
need_min_p_sampling: bool
|
||||
|
||||
@@ -133,6 +141,8 @@ class SamplingBatchInfo:
|
||||
top_ks=top_ks,
|
||||
min_ps=min_ps,
|
||||
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
||||
need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
|
||||
need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
vocab_size=vocab_size,
|
||||
penalizer_orchestrator=penalizer_orchestrator,
|
||||
@@ -167,7 +177,7 @@ class SamplingBatchInfo:
|
||||
|
||||
# Apply the mask
|
||||
for i, grammar in enumerate(self.grammars):
|
||||
- if grammar and not grammar.finished:
|
||||
+ if grammar and not grammar.finished and not grammar.is_terminated():
|
||||
grammar.fill_vocab_mask(self.vocab_mask, i)
|
||||
|
||||
# Move the mask to the device if needed
|
||||
@@ -308,4 +318,6 @@ class SamplingBatchInfo:
|
||||
setattr(self, item, torch.cat([self_val, other_val]))
|
||||
|
||||
self.is_all_greedy &= other.is_all_greedy
|
||||
self.need_top_p_sampling |= other.need_top_p_sampling
|
||||
self.need_top_k_sampling |= other.need_top_k_sampling
|
||||
self.need_min_p_sampling |= other.need_min_p_sampling
|
||||
|
||||
Reference in New Issue
Block a user