Refactor logprob computation to return the real logprob used in sampling (#2664)
@@ -36,7 +36,7 @@ from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.sampler import Sampler, get_top_logprobs
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
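The sampler module now also exports a get_top_logprobs helper. Its body is not part of this diff; the following is a minimal sketch of what a batched top-k logprob helper of this shape could look like, where the whole function body is an assumption (only the name comes from the import above):

import torch

def get_top_logprobs_sketch(logprobs: torch.Tensor, top_logprobs_nums: list):
    # logprobs: [batch, vocab] log-probabilities after sampling post-processing.
    # top_logprobs_nums: the per-request k requested by each client.
    max_k = max(top_logprobs_nums)
    # One batched top-k over the whole batch, then slice each row to its own k.
    values, indices = logprobs.topk(max_k, dim=-1)
    values, indices = values.tolist(), indices.tolist()
    return (
        [values[i][:k] for i, k in enumerate(top_logprobs_nums)],
        [indices[i][:k] for i, k in enumerate(top_logprobs_nums)],
    )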
@@ -48,7 +48,6 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,
@@ -192,7 +191,8 @@ class ModelRunner:
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
-        # ToDO(liangan1):Just use gloo to bypass the initilization fail
+        # TODO(liangan1):Just use gloo to bypass the initilization fail
+        # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
             backend = "gloo"
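The corrected comments above concern which torch.distributed backend gets initialized for each device type. A standalone sketch of that selection logic, runnable on a CPU-only machine via gloo (the init_method, rank, and world size are placeholder values for illustration, not sglang's setup):

import torch
import torch.distributed as dist

# Mirror the diff's policy: NCCL for CUDA GPUs; gloo as the XPU
# workaround until an XCCL backend is available.
device = "cuda" if torch.cuda.is_available() else "xpu"
backend = "nccl" if device == "cuda" else "gloo"

# Single-process process group so the sketch runs standalone.
dist.init_process_group(
    backend=backend,
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
)
dist.destroy_process_group()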
@@ -704,6 +704,7 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
+        # Apply logit bias
         sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
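For readers unfamiliar with overlap mode: sampling_info_done is an event that another thread sets once the CPU-heavy mask construction from the previous batch's result processing has finished, so sample() only blocks when that work is still in flight. A toy illustration of the pattern with threading.Event (everything except the sampling_info_done name is hypothetical):

import threading
import time

sampling_info_done = threading.Event()

def process_batch_result():
    # Stand-in for the CPU-heavy work (e.g., building the regex vocab mask)
    # that runs on another thread while the GPU forward pass proceeds.
    time.sleep(0.01)
    sampling_info_done.set()

worker = threading.Thread(target=process_batch_result)
worker.start()

# ... the GPU forward pass would overlap with the worker here ...

# Before sampling, wait only if the mask is still being built.
sampling_info_done.wait()
worker.join()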
@@ -714,35 +715,17 @@ class ModelRunner:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
             sampling_info.update_penalties()
-        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
+        sampling_info.apply_logits_bias(logits_output.next_token_logits)
 
-        # Sample the next tokens.
-        next_token_ids = self.sampler(logits, sampling_info)
+        # Sample the next tokens
+        next_token_ids = self.sampler(
+            logits_output,
+            sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+        )
         return next_token_ids
 
-    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Apply logit_bias
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits.add_(sampling_info.linear_penalties)
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        # Apply regex vocab_mask
-        if sampling_info.vocab_mask is not None:
-            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
-
-        return logits
-
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
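This hunk is the heart of the refactor: instead of sampling from locally post-processed logits and letting logprobs be computed elsewhere from the raw distribution, the sampler now receives the full LogitsProcessorOutput plus return_logprob and top_logprobs_nums, so logprobs can be read off the same post-processed distribution the token was actually drawn from. A simplified sketch of that idea, not sglang's actual Sampler code:

import torch

def sample_with_real_logprobs(logits: torch.Tensor, return_logprob: bool):
    # `logits` are assumed to already have bias, penalties, and vocab
    # masks applied, as apply_logits_bias does above.
    probs = torch.softmax(logits, dim=-1)
    next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
    token_logprobs = None
    if return_logprob:
        logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
        # Gather the logprob of the token that was actually sampled,
        # under the same distribution used for sampling.
        token_logprobs = logprobs.gather(-1, next_token_ids.unsqueeze(-1)).squeeze(-1)
    return next_token_ids, token_logprobs

# Usage: a batch of 2 requests over a toy vocab of 5 tokens.
ids, lp = sample_with_real_logprobs(torch.randn(2, 5), return_logprob=True)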