[Auto Sync] Update sampling_batch_info.py (20250909) (#10212)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: cctry <shiyang@x.ai>
This commit is contained in:
@@ -67,28 +67,31 @@ class SamplingBatchInfo:
|
||||
logit_bias: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
def _get_global_server_args_dict(cls):
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
return global_server_args_dict
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
global_server_args_dict = cls._get_global_server_args_dict()
|
||||
|
||||
reqs = batch.reqs
|
||||
device = batch.device
|
||||
temperatures = (
|
||||
torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
)
|
||||
.view(-1, 1)
|
||||
.to(device, non_blocking=True)
|
||||
)
|
||||
temperatures = torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
).view(-1, 1)
|
||||
top_ps = torch.tensor(
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||
).to(device, non_blocking=True)
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
|
||||
)
|
||||
top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int32, device=device
|
||||
)
|
||||
min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
).to(device, non_blocking=True)
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
|
||||
)
|
||||
|
||||
logit_bias = None
|
||||
if any(r.sampling_params.logit_bias is not None for r in reqs):
|
||||
|
||||
Reference in New Issue
Block a user