Simplify the nan detection and greedy check in sampler (#1709)

This commit is contained in:
Lianmin Zheng
2024-10-18 20:21:24 -07:00
committed by GitHub
parent 2bcfba1b08
commit f0f8a7699b
6 changed files with 24 additions and 7 deletions

View File

@@ -20,6 +20,9 @@ class SamplingBatchInfo:
top_ks: torch.Tensor
min_ps: torch.Tensor
# All requests use greedy sampling
is_all_greedy: bool
# Dispatch in CUDA graph
need_min_p_sampling: bool
@@ -73,6 +76,7 @@ class SamplingBatchInfo:
top_ks=top_ks,
min_ps=min_ps,
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
is_all_greedy=top_ks.max().item() <= 1,
vocab_size=vocab_size,
device=batch.input_ids.device,
)
@@ -204,6 +208,7 @@ class SamplingBatchInfo:
other_val = getattr(other, item, None)
setattr(self, item, torch.concat([self_val, other_val]))
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)