Bump sgl-kernel from 0.0.9.post1 to 0.0.9.post2 (pyproject.toml and the CI
install script) and adapt the sampler call sites in sampler.py.

The sampling helpers are now called without the pre-generated uniform_samples
tensor, and top_k_top_p_sampling_from_probs is used as returning only the
sampled token ids (no success flag). Because the success flag is gone, the
NaN-detection fallback is gated on an explicit NaN check of probs; gating it on
use_nan_detection alone would zero every sampled token whenever the option is on.

diff --git a/python/pyproject.toml b/python/pyproject.toml
index 0b22b8cb9..5dbe0bf41 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -47,7 +47,7 @@ runtime_common = [
 
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.9.post1",
+    "sgl-kernel==0.0.9.post2",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "torchvision==0.20.1",
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index fcf2af9ea..e0f434a19 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -93,25 +93,22 @@ class Sampler(nn.Module):
                     ).clamp(min=torch.finfo(probs.dtype).min)
 
                 max_top_k_round, batch_size = 32, probs.shape[0]
-                uniform_samples = torch.rand(
-                    (max_top_k_round, batch_size), device=probs.device
-                )
                 if sampling_info.need_min_p_sampling:
                     probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                     probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                     batch_next_token_ids = min_p_sampling_from_probs(
-                        probs, uniform_samples, sampling_info.min_ps
+                        probs, sampling_info.min_ps
                     )
                 else:
-                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    batch_next_token_ids = top_k_top_p_sampling_from_probs(
                         probs,
-                        uniform_samples,
                         sampling_info.top_ks,
                         sampling_info.top_ps,
                         filter_apply_order="joint",
                     )
 
-                    if self.use_nan_detection and not torch.all(success):
+                    # success flag is gone: detect NaNs in probs directly.
+                    if self.use_nan_detection and torch.any(torch.isnan(probs)):
                         logger.warning("Detected errors during sampling!")
                         batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
 
diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh
index 99326bdd2..b7528f1d8 100755
--- a/scripts/ci_install_dependency.sh
+++ b/scripts/ci_install_dependency.sh
@@ -20,7 +20,7 @@ pip install --upgrade pip
 
 # Install flashinfer and sgl-kernel
 pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --no-cache-dir
-pip install sgl-kernel==0.0.9.post1 --no-cache-dir
+pip install sgl-kernel==0.0.9.post2 --no-cache-dir
 
 # Install the main package
 pip install -e "python[all]" --find-links ${FLASHINFER_REPO}