Minor follow-up fixes for the logprob refactor (#2670)
@@ -35,21 +35,21 @@ from sglang.srt.model_executor.forward_batch_info import (
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
-    ## First part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor.
+    ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # Used by speculative decoding (EAGLE)
     # The last hidden layers
     hidden_states: Optional[torch.Tensor] = None
 
-    ## Second part. This part will be returned by python/sglang/srt/layers/sampler.py::Sampler.
+    ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
     # The logprobs of the next tokens. shape: [#seq]
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None
 
-    ## Third part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor. Prefill-only.
+    ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
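For orientation, here is a minimal sketch of the two-stage fill these comments describe: the LogitsProcessor constructs the object with Part 1, and the Sampler later assigns Part 2 in place. The field names come from the diff above (Part 3 is omitted); the surrounding wiring is assumed, not taken from the commit.

# Minimal sketch, assuming the two-stage fill implied by the comments above.
# Fields are a subset of the diff; the pipeline wiring is illustrative.
import dataclasses
from typing import List, Optional

import torch


@dataclasses.dataclass
class LogitsProcessorOutput:
    next_token_logits: torch.Tensor                     # Part 1: set by LogitsProcessor
    hidden_states: Optional[torch.Tensor] = None
    next_token_logprobs: Optional[torch.Tensor] = None  # Part 2: set by Sampler
    next_token_top_logprobs_val: Optional[List] = None
    next_token_top_logprobs_idx: Optional[List] = None


# Stage 1: the logits processor creates the object with only Part 1 filled.
out = LogitsProcessorOutput(next_token_logits=torch.randn(2, 32000))

# Stage 2: the sampler assigns Part 2 on the same object instead of returning new values.
probs = torch.softmax(out.next_token_logits, dim=-1)
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
out.next_token_logprobs = torch.log(probs[torch.arange(2), sampled])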
@@ -56,7 +56,9 @@ class Sampler(nn.Module):
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
             if return_logprob:
-                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems
+                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
+                # https://github.com/flashinfer-ai/flashinfer/issues/708
+                # so we use the torch implementation.
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                 )
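For context, top-p renormalization keeps the smallest set of highest-probability tokens whose cumulative mass reaches top_p, zeroes the rest, and rescales the survivors to sum to 1. Below is a minimal pure-torch sketch of that technique; it illustrates the operation, not necessarily the exact body of top_p_normalize_probs_torch.

# Hedged sketch of per-sequence top-p renormalization in pure torch.
# `probs`: [num_seqs, vocab_size] probabilities; `top_ps`: [num_seqs] thresholds.
import torch


def top_p_renormalize_sketch(probs: torch.Tensor, top_ps: torch.Tensor) -> torch.Tensor:
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    # Keep a token while the cumulative mass *before* it is under top_p
    # (this always keeps the top-1 token).
    mask = (cumsum - sorted_probs) < top_ps.unsqueeze(-1)
    kept = sorted_probs * mask
    renormed = kept / kept.sum(dim=-1, keepdim=True)
    # Undo the sort so columns line up with the original vocab order.
    return torch.zeros_like(probs).scatter(-1, sorted_idx, renormed)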
@@ -36,7 +36,7 @@ from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import Sampler, get_top_logprobs
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -191,10 +191,9 @@ class ModelRunner:
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
-
-            # TODO(liangan1):Just use gloo to bypass the initilization fail
-            # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
+            # TODO(liangan1):Just use gloo to bypass the initilization fail
+            # Need to use xccl for xpu backend in the future
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
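The comment move above documents why the xpu branch falls back to gloo until xccl is usable. A minimal sketch of how such a device-dependent backend string typically reaches torch.distributed follows; the call site and init arguments here are illustrative assumptions, not code from this hunk.

# Hedged sketch: device-dependent backend selection feeding torch.distributed.
# The rendezvous address, world_size, and rank are hypothetical placeholders.
import torch.distributed as dist

device = "xpu"
if device == "cuda":
    backend = "nccl"
elif device == "xpu":
    backend = "gloo"  # workaround until an xccl backend is available
elif device == "hpu":
    backend = "hccl"
else:
    backend = "gloo"

dist.init_process_group(
    backend=backend,
    init_method="tcp://127.0.0.1:29500",  # hypothetical rendezvous address
    world_size=1,
    rank=0,
)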
@@ -244,7 +244,7 @@ class SamplingBatchInfo:
 
         # repetition
         if self.scaling_penalties is not None:
-            logits = torch.where(
+            logits[:] = torch.where(
                 logits > 0,
                 logits / self.scaling_penalties,
                 logits * self.scaling_penalties,
@@ -253,5 +253,3 @@ class SamplingBatchInfo:
         # Apply regex vocab_mask
         if self.vocab_mask is not None:
             self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
-
-        return logits
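These two hunks go together: once `return logits` is removed, the penalty must be applied through the caller's tensor, so `logits = torch.where(...)` (which only rebinds the local name) becomes `logits[:] = torch.where(...)` (which writes into the existing storage). A self-contained illustration of the difference:

# Why `logits[:] = ...` instead of `logits = ...` once the return is gone:
# slice assignment mutates the caller's tensor; plain assignment does not.
import torch


def rebind(logits: torch.Tensor) -> None:
    logits = torch.where(logits > 0, logits / 2, logits * 2)  # local rebind only


def write_in_place(logits: torch.Tensor) -> None:
    logits[:] = torch.where(logits > 0, logits / 2, logits * 2)  # visible to caller


t = torch.tensor([4.0, -4.0])
rebind(t)
print(t)           # tensor([ 4., -4.])  -- unchanged
write_in_place(t)
print(t)           # tensor([ 2., -8.])  -- updated in place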
@@ -227,7 +227,7 @@ class TestSRTEndpoint(unittest.TestCase):
                 "regex": "( Yes| No)",
             },
             "return_logprob": True,
-            "top_logprobs_num": 5,
+            "top_logprobs_num": 5,  # The grammar constraint allows all prefix tokens so we need to use a larger top_k.
             "return_text_in_logprobs": True,
         },
     )
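For reference, the test body above corresponds to a request shaped roughly like the following. The server address and prompt are placeholders; only the sampling and logprob fields are taken from the diff.

# Hedged sketch of the request this test exercises against a running server.
import requests

response = requests.post(
    "http://127.0.0.1:30000/generate",  # assumed local server address
    json={
        "text": "Is the sky blue?",  # hypothetical prompt
        "sampling_params": {
            "max_new_tokens": 1,
            "regex": "( Yes| No)",
        },
        "return_logprob": True,
        # The grammar constraint allows all prefix tokens, so request a larger top-k.
        "top_logprobs_num": 5,
        "return_text_in_logprobs": True,
    },
)
print(response.json())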