From 71133a0426a331b182ec27294736468879ae21f4 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 9 Sep 2025 01:29:52 -0700
Subject: [PATCH] [Auto Sync] Update sampling_batch_info.py (20250909) (#10212)

Co-authored-by: github-actions[bot]
Co-authored-by: cctry
---
 .../srt/sampling/sampling_batch_info.py | 33 ++++++++++---------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index ec649f479..6ba8a7777 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -67,28 +67,31 @@ class SamplingBatchInfo:
     logit_bias: Optional[torch.Tensor] = None
 
     @classmethod
-    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+    def _get_global_server_args_dict(cls):
         from sglang.srt.managers.schedule_batch import global_server_args_dict
 
+        return global_server_args_dict
+
+    @classmethod
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        global_server_args_dict = cls._get_global_server_args_dict()
+
         reqs = batch.reqs
         device = batch.device
-        temperatures = (
-            torch.tensor(
-                [r.sampling_params.temperature for r in reqs],
-                dtype=torch.float,
-            )
-            .view(-1, 1)
-            .to(device, non_blocking=True)
-        )
+        temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
         top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float
-        ).to(device, non_blocking=True)
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
         top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
-        ).to(device, non_blocking=True)
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int32, device=device
+        )
         min_ps = torch.tensor(
-            [r.sampling_params.min_p for r in reqs], dtype=torch.float
-        ).to(device, non_blocking=True)
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+        )
 
         logit_bias = None
         if any(r.sampling_params.logit_bias is not None for r in reqs):