chore: upgrade sgl-kernel 0.0.9.post2 (#5540)

This commit is contained in:
Yineng Zhang
2025-04-18 21:17:23 -07:00
committed by GitHub
parent a6f892e5d0
commit 2c11f9c2eb
3 changed files with 5 additions and 9 deletions

View File

@@ -93,25 +93,21 @@ class Sampler(nn.Module):
).clamp(min=torch.finfo(probs.dtype).min)
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
probs, sampling_info.min_ps
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
batch_next_token_ids = top_k_top_p_sampling_from_probs(
probs,
uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
)
if self.use_nan_detection and not torch.all(success):
if self.use_nan_detection:
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)