Use int64 as indices for set_kv_buffer (#3039)

This commit is contained in:
Lianmin Zheng
2025-01-21 19:46:09 -08:00
committed by GitHub
parent a42213dbd4
commit 3d8f1c9bcf
6 changed files with 30 additions and 37 deletions

View File

@@ -1,12 +1,11 @@
import logging
-from typing import Dict, List
+from typing import List
import torch
from torch import nn
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
@@ -109,8 +108,6 @@ class Sampler(nn.Module):
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)
-batch_next_token_ids = batch_next_token_ids.to(torch.int32)
# Attach logprobs to logits_output (in-place modification)
if return_logprob:
if any(x > 0 for x in top_logprobs_nums):
@@ -124,7 +121,7 @@ class Sampler(nn.Module):
batch_next_token_ids,
]
-return batch_next_token_ids
+return batch_next_token_ids.to(torch.int32)
def _apply_custom_logit_processor(
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo