Use int64 as indices for set_kv_buffer (#3039)
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
import logging
|
||||
-from typing import Dict, List
|
||||
+from typing import List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
|
||||
|
||||
@@ -109,8 +108,6 @@ class Sampler(nn.Module):
|
||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||
)
|
||||
|
||||
-batch_next_token_ids = batch_next_token_ids.to(torch.int32)
|
||||
|
||||
# Attach logprobs to logits_output (in-place modification)
|
||||
if return_logprob:
|
||||
if any(x > 0 for x in top_logprobs_nums):
|
||||
@@ -124,7 +121,7 @@ class Sampler(nn.Module):
|
||||
batch_next_token_ids,
|
||||
]
|
||||
|
||||
-return batch_next_token_ids
|
||||
+return batch_next_token_ids.to(torch.int32)
|
||||
|
||||
def _apply_custom_logit_processor(
|
||||
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
|
||||
|
||||
Reference in New Issue
Block a user