Slightly improve the sampler to skip unnecessary steps (#6956)
This commit is contained in:
@@ -9,10 +9,12 @@ import torch
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.sampling_params import TOP_K_ALL
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -27,6 +29,12 @@ class SamplingBatchInfo:
|
||||
# Whether all requests use greedy sampling
|
||||
is_all_greedy: bool
|
||||
|
||||
# Whether any request uses top_p sampling (i.e. has top_p != 1.0)
|
||||
need_top_p_sampling: bool
|
||||
|
||||
# Whether any request uses top_k sampling (i.e. has top_k != TOP_K_ALL)
|
||||
need_top_k_sampling: bool
|
||||
|
||||
# Whether any request uses min_p sampling (i.e. has min_p > 0)
|
||||
need_min_p_sampling: bool
|
||||
|
||||
@@ -133,6 +141,8 @@ class SamplingBatchInfo:
|
||||
top_ks=top_ks,
|
||||
min_ps=min_ps,
|
||||
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
||||
need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
|
||||
need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
vocab_size=vocab_size,
|
||||
penalizer_orchestrator=penalizer_orchestrator,
|
||||
@@ -167,7 +177,7 @@ class SamplingBatchInfo:
|
||||
|
||||
# Apply the mask
|
||||
for i, grammar in enumerate(self.grammars):
|
||||
if grammar and not grammar.finished:
|
||||
if grammar and not grammar.finished and not grammar.is_terminated():
|
||||
grammar.fill_vocab_mask(self.vocab_mask, i)
|
||||
|
||||
# Move the mask to the device if needed
|
||||
@@ -308,4 +318,6 @@ class SamplingBatchInfo:
|
||||
setattr(self, item, torch.cat([self_val, other_val]))
|
||||
|
||||
self.is_all_greedy &= other.is_all_greedy
|
||||
self.need_top_p_sampling |= other.need_top_p_sampling
|
||||
self.need_top_k_sampling |= other.need_top_k_sampling
|
||||
self.need_min_p_sampling |= other.need_min_p_sampling
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
_SAMPLING_EPS = 1e-6
|
||||
TOP_K_ALL = 1 << 30
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
@@ -84,7 +85,7 @@ class SamplingParams:
|
||||
self.temperature = 1.0
|
||||
self.top_k = 1
|
||||
if self.top_k == -1:
|
||||
self.top_k = 1 << 30 # whole vocabulary
|
||||
self.top_k = TOP_K_ALL # whole vocabulary
|
||||
|
||||
def verify(self):
|
||||
if self.temperature < 0.0:
|
||||
|
||||
Reference in New Issue
Block a user