diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index ec649f479..6ba8a7777 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -67,28 +67,31 @@ class SamplingBatchInfo: logit_bias: Optional[torch.Tensor] = None @classmethod - def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): + def _get_global_server_args_dict(cls): from sglang.srt.managers.schedule_batch import global_server_args_dict + return global_server_args_dict + + @classmethod + def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): + global_server_args_dict = cls._get_global_server_args_dict() + reqs = batch.reqs device = batch.device - temperatures = ( - torch.tensor( - [r.sampling_params.temperature for r in reqs], - dtype=torch.float, - ) - .view(-1, 1) - .to(device, non_blocking=True) - ) + temperatures = torch.tensor( + [r.sampling_params.temperature for r in reqs], + dtype=torch.float, + device=device, + ).view(-1, 1) top_ps = torch.tensor( - [r.sampling_params.top_p for r in reqs], dtype=torch.float - ).to(device, non_blocking=True) + [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device + ) top_ks = torch.tensor( - [r.sampling_params.top_k for r in reqs], dtype=torch.int32 - ).to(device, non_blocking=True) + [r.sampling_params.top_k for r in reqs], dtype=torch.int32, device=device + ) min_ps = torch.tensor( - [r.sampling_params.min_p for r in reqs], dtype=torch.float - ).to(device, non_blocking=True) + [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device + ) logit_bias = None if any(r.sampling_params.logit_bias is not None for r in reqs):