chore: upgrade sgl-kernel 0.0.9.post2 (#5540)
This commit is contained in:
@@ -47,7 +47,7 @@ runtime_common = [
|
|||||||
|
|
||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"sgl-kernel==0.0.9.post1",
|
"sgl-kernel==0.0.9.post2",
|
||||||
"flashinfer_python==0.2.3",
|
"flashinfer_python==0.2.3",
|
||||||
"torch==2.5.1",
|
"torch==2.5.1",
|
||||||
"torchvision==0.20.1",
|
"torchvision==0.20.1",
|
||||||
|
|||||||
@@ -93,25 +93,21 @@ class Sampler(nn.Module):
|
|||||||
).clamp(min=torch.finfo(probs.dtype).min)
|
).clamp(min=torch.finfo(probs.dtype).min)
|
||||||
|
|
||||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
uniform_samples = torch.rand(
|
|
||||||
(max_top_k_round, batch_size), device=probs.device
|
|
||||||
)
|
|
||||||
if sampling_info.need_min_p_sampling:
|
if sampling_info.need_min_p_sampling:
|
||||||
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
||||||
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
||||||
batch_next_token_ids = min_p_sampling_from_probs(
|
batch_next_token_ids = min_p_sampling_from_probs(
|
||||||
probs, uniform_samples, sampling_info.min_ps
|
probs, sampling_info.min_ps
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
batch_next_token_ids = top_k_top_p_sampling_from_probs(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
|
||||||
sampling_info.top_ks,
|
sampling_info.top_ks,
|
||||||
sampling_info.top_ps,
|
sampling_info.top_ps,
|
||||||
filter_apply_order="joint",
|
filter_apply_order="joint",
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_nan_detection and not torch.all(success):
|
if self.use_nan_detection:
|
||||||
logger.warning("Detected errors during sampling!")
|
logger.warning("Detected errors during sampling!")
|
||||||
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ pip install --upgrade pip
|
|||||||
|
|
||||||
# Install flashinfer and sgl-kernel
|
# Install flashinfer and sgl-kernel
|
||||||
pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --no-cache-dir
|
pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --no-cache-dir
|
||||||
pip install sgl-kernel==0.0.9.post1 --no-cache-dir
|
pip install sgl-kernel==0.0.9.post2 --no-cache-dir
|
||||||
|
|
||||||
# Install the main package
|
# Install the main package
|
||||||
pip install -e "python[all]" --find-links ${FLASHINFER_REPO}
|
pip install -e "python[all]" --find-links ${FLASHINFER_REPO}
|
||||||
|
|||||||
Reference in New Issue
Block a user