chore: upgrade sgl-kernel to 0.0.9.post2 (#5540)

This commit is contained in:
Yineng Zhang
2025-04-18 21:17:23 -07:00
committed by GitHub
parent a6f892e5d0
commit 2c11f9c2eb
3 changed files with 5 additions and 9 deletions

View File

@@ -47,7 +47,7 @@ runtime_common = [
srt = [
"sglang[runtime_common]",
-"sgl-kernel==0.0.9.post1",
+"sgl-kernel==0.0.9.post2",
"flashinfer_python==0.2.3",
"torch==2.5.1",
"torchvision==0.20.1",

View File

@@ -93,25 +93,21 @@ class Sampler(nn.Module):
).clamp(min=torch.finfo(probs.dtype).min)
max_top_k_round, batch_size = 32, probs.shape[0]
-uniform_samples = torch.rand(
-(max_top_k_round, batch_size), device=probs.device
-)
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
-probs, uniform_samples, sampling_info.min_ps
+probs, sampling_info.min_ps
)
else:
-batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+batch_next_token_ids = top_k_top_p_sampling_from_probs(
probs,
-uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
)
-if self.use_nan_detection and not torch.all(success):
+if self.use_nan_detection:
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)