Simplify the nan detection and greedy check in sampler (#1709)
This commit is contained in:
@@ -20,6 +20,9 @@ class SamplingBatchInfo:
|
||||
top_ks: torch.Tensor
|
||||
min_ps: torch.Tensor
|
||||
|
||||
# All requests use greedy sampling
|
||||
is_all_greedy: bool
|
||||
|
||||
# Dispatch in CUDA graph
|
||||
need_min_p_sampling: bool
|
||||
|
||||
@@ -73,6 +76,7 @@ class SamplingBatchInfo:
|
||||
top_ks=top_ks,
|
||||
min_ps=min_ps,
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
is_all_greedy=top_ks.max().item() <= 1,
|
||||
vocab_size=vocab_size,
|
||||
device=batch.input_ids.device,
|
||||
)
|
||||
@@ -204,6 +208,7 @@ class SamplingBatchInfo:
|
||||
other_val = getattr(other, item, None)
|
||||
setattr(self, item, torch.concat([self_val, other_val]))
|
||||
|
||||
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user