[Auto Sync] Update sampling_batch_info.py (20250909) (#10212)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: cctry <shiyang@x.ai>
This commit is contained in:
Lianmin Zheng
2025-09-09 01:29:52 -07:00
committed by GitHub
parent 2cd94dd07e
commit 71133a0426

View File

@@ -67,28 +67,31 @@ class SamplingBatchInfo:
logit_bias: Optional[torch.Tensor] = None
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
def _get_global_server_args_dict(cls):
    """Return the process-wide ``global_server_args_dict``.

    The import happens inside the method body rather than at module
    level — presumably to break a circular import with
    ``sglang.srt.managers.schedule_batch`` (confirm against the module
    graph).
    """
    from sglang.srt.managers import schedule_batch

    return schedule_batch.global_server_args_dict
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
global_server_args_dict = cls._get_global_server_args_dict()
reqs = batch.reqs
device = batch.device
temperatures = (
torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
)
.view(-1, 1)
.to(device, non_blocking=True)
)
temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
device=device,
).view(-1, 1)
top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
)
top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
).to(device, non_blocking=True)
[r.sampling_params.top_k for r in reqs], dtype=torch.int32, device=device
)
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)
logit_bias = None
if any(r.sampling_params.logit_bias is not None for r in reqs):