Refactor logprob computation to return the real logprob used in sampling (#2664)
This commit is contained in:
@@ -17,6 +17,8 @@ import dataclasses
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@@ -33,76 +35,77 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class LogitsProcessorOutput:
|
class LogitsProcessorOutput:
|
||||||
|
## First part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor.
|
||||||
# The logits of the next tokens. shape: [#seq, vocab_size]
|
# The logits of the next tokens. shape: [#seq, vocab_size]
|
||||||
next_token_logits: torch.Tensor
|
next_token_logits: torch.Tensor
|
||||||
# The logprobs of the next tokens. shape: [#seq, vocab_size]
|
# Used by speculative decoding (EAGLE)
|
||||||
next_token_logprobs: torch.Tensor = None
|
# The last hidden layers
|
||||||
|
hidden_states: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
## Second part. This part will be returned by python/sglang/srt/layers/sampler.py::Sampler.
|
||||||
|
# The logprobs of the next tokens. shape: [#seq]
|
||||||
|
next_token_logprobs: Optional[torch.Tensor] = None
|
||||||
|
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
|
||||||
|
next_token_top_logprobs_val: Optional[List] = None
|
||||||
|
next_token_top_logprobs_idx: Optional[List] = None
|
||||||
|
|
||||||
|
## Third part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor. Prefill-only.
|
||||||
# The normlaized logprobs of prompts. shape: [#seq]
|
# The normlaized logprobs of prompts. shape: [#seq]
|
||||||
normalized_prompt_logprobs: torch.Tensor = None
|
normalized_prompt_logprobs: torch.Tensor = None
|
||||||
# The logprobs of input tokens. shape: [#token, vocab_size]
|
# The logprobs of input tokens. shape: [#token]
|
||||||
input_token_logprobs: torch.Tensor = None
|
input_token_logprobs: torch.Tensor = None
|
||||||
|
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
|
||||||
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
|
|
||||||
input_top_logprobs_val: List = None
|
input_top_logprobs_val: List = None
|
||||||
input_top_logprobs_idx: List = None
|
input_top_logprobs_idx: List = None
|
||||||
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
|
|
||||||
output_top_logprobs_val: List = None
|
|
||||||
output_top_logprobs_idx: List = None
|
|
||||||
|
|
||||||
# Used by speculative decoding (EAGLE)
|
|
||||||
# The output of transformer layers
|
|
||||||
hidden_states: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class LogitsMetadata:
|
class LogitsMetadata:
|
||||||
forward_mode: ForwardMode
|
forward_mode: ForwardMode
|
||||||
top_logprobs_nums: Optional[List[int]]
|
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
|
||||||
|
|
||||||
return_logprob: bool = False
|
|
||||||
return_top_logprob: bool = False
|
|
||||||
|
|
||||||
|
extend_return_logprob: bool = False
|
||||||
|
extend_return_top_logprob: bool = False
|
||||||
extend_seq_lens: Optional[torch.Tensor] = None
|
extend_seq_lens: Optional[torch.Tensor] = None
|
||||||
extend_seq_lens_cpu: Optional[List[int]] = None
|
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||||
|
|
||||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||||
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
|
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
|
||||||
|
top_logprobs_nums: Optional[List[int]] = None
|
||||||
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||||
extend_logprob_pruned_lens_cpu = None
|
|
||||||
|
|
||||||
if forward_batch.return_logprob:
|
|
||||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
|
||||||
if forward_batch.forward_mode.is_extend():
|
|
||||||
extend_logprob_pruned_lens_cpu = [
|
|
||||||
extend_len - start_len
|
|
||||||
for extend_len, start_len in zip(
|
|
||||||
forward_batch.extend_seq_lens_cpu,
|
|
||||||
forward_batch.extend_logprob_start_lens_cpu,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
return_top_logprob = False
|
|
||||||
|
|
||||||
if forward_batch.spec_info:
|
if forward_batch.spec_info:
|
||||||
capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
|
capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
|
||||||
else:
|
else:
|
||||||
capture_hidden_mode = CaptureHiddenMode.NULL
|
capture_hidden_mode = CaptureHiddenMode.NULL
|
||||||
|
|
||||||
|
if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
|
||||||
|
extend_return_logprob = True
|
||||||
|
extend_return_top_logprob = any(
|
||||||
|
x > 0 for x in forward_batch.top_logprobs_nums
|
||||||
|
)
|
||||||
|
extend_logprob_pruned_lens_cpu = [
|
||||||
|
extend_len - start_len
|
||||||
|
for extend_len, start_len in zip(
|
||||||
|
forward_batch.extend_seq_lens_cpu,
|
||||||
|
forward_batch.extend_logprob_start_lens_cpu,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
extend_return_logprob = extend_return_top_logprob = (
|
||||||
|
extend_logprob_pruned_lens_cpu
|
||||||
|
) = False
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
forward_mode=forward_batch.forward_mode,
|
forward_mode=forward_batch.forward_mode,
|
||||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
capture_hidden_mode=capture_hidden_mode,
|
||||||
return_logprob=forward_batch.return_logprob,
|
extend_return_logprob=extend_return_logprob,
|
||||||
return_top_logprob=return_top_logprob,
|
extend_return_top_logprob=extend_return_top_logprob,
|
||||||
extend_seq_lens=forward_batch.extend_seq_lens,
|
extend_seq_lens=forward_batch.extend_seq_lens,
|
||||||
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||||
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
|
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
|
||||||
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
|
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
|
||||||
capture_hidden_mode=capture_hidden_mode,
|
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -129,7 +132,6 @@ class LogitsProcessor(nn.Module):
|
|||||||
):
|
):
|
||||||
if isinstance(logits_metadata, ForwardBatch):
|
if isinstance(logits_metadata, ForwardBatch):
|
||||||
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
||||||
assert isinstance(logits_metadata, LogitsMetadata)
|
|
||||||
|
|
||||||
# Get the last hidden states and last logits for the next token prediction
|
# Get the last hidden states and last logits for the next token prediction
|
||||||
if (
|
if (
|
||||||
@@ -142,18 +144,10 @@ class LogitsProcessor(nn.Module):
|
|||||||
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
||||||
last_hidden = hidden_states[last_index]
|
last_hidden = hidden_states[last_index]
|
||||||
|
|
||||||
|
# Compute logits
|
||||||
last_logits = self._get_logits(last_hidden, lm_head)
|
last_logits = self._get_logits(last_hidden, lm_head)
|
||||||
if self.do_tensor_parallel_all_gather:
|
if not logits_metadata.extend_return_logprob:
|
||||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
# Decode mode or extend mode without return_logprob.
|
||||||
last_logits = last_logits[:, : self.config.vocab_size].float()
|
|
||||||
|
|
||||||
if self.final_logit_softcapping:
|
|
||||||
last_logits.div_(self.final_logit_softcapping)
|
|
||||||
torch.tanh(last_logits, out=last_logits)
|
|
||||||
last_logits.mul_(self.final_logit_softcapping)
|
|
||||||
|
|
||||||
# Return only last_logits if logprob is not requested
|
|
||||||
if not logits_metadata.return_logprob:
|
|
||||||
return LogitsProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
hidden_states=(
|
hidden_states=(
|
||||||
@@ -167,95 +161,60 @@ class LogitsProcessor(nn.Module):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
last_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
# Slice the requested tokens to compute logprob
|
||||||
last_logits, logits_metadata
|
pt, pruned_states, pruned_input_ids = 0, [], []
|
||||||
|
for start_len, extend_len in zip(
|
||||||
|
logits_metadata.extend_logprob_start_lens_cpu,
|
||||||
|
logits_metadata.extend_seq_lens_cpu,
|
||||||
|
):
|
||||||
|
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
|
||||||
|
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
||||||
|
pt += extend_len
|
||||||
|
|
||||||
|
# Compute the logits of all required tokens
|
||||||
|
pruned_states = torch.cat(pruned_states)
|
||||||
|
del hidden_states
|
||||||
|
input_token_logits = self._get_logits(pruned_states, lm_head)
|
||||||
|
del pruned_states
|
||||||
|
|
||||||
|
# Normalize the logprob w/o temperature, top-p
|
||||||
|
input_logprobs = input_token_logits
|
||||||
|
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||||
|
input_logprobs, logits_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
if logits_metadata.forward_mode.is_decode():
|
# Get the logprob of top-k tokens
|
||||||
if logits_metadata.return_top_logprob:
|
if logits_metadata.extend_return_top_logprob:
|
||||||
output_top_logprobs_val, output_top_logprobs_idx = (
|
(
|
||||||
self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
|
input_top_logprobs_val,
|
||||||
)
|
input_top_logprobs_idx,
|
||||||
else:
|
) = self.get_top_logprobs(input_logprobs, logits_metadata)
|
||||||
output_top_logprobs_val = output_top_logprobs_idx = None
|
|
||||||
return LogitsProcessorOutput(
|
|
||||||
next_token_logits=last_logits,
|
|
||||||
next_token_logprobs=last_logprobs,
|
|
||||||
output_top_logprobs_val=output_top_logprobs_val,
|
|
||||||
output_top_logprobs_idx=output_top_logprobs_idx,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Slice the requested tokens to compute logprob
|
input_top_logprobs_val = input_top_logprobs_idx = None
|
||||||
pt, states, pruned_input_ids = 0, [], []
|
|
||||||
for start_len, extend_len in zip(
|
|
||||||
logits_metadata.extend_logprob_start_lens_cpu,
|
|
||||||
logits_metadata.extend_seq_lens_cpu,
|
|
||||||
):
|
|
||||||
states.append(hidden_states[pt + start_len : pt + extend_len])
|
|
||||||
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
|
||||||
pt += extend_len
|
|
||||||
|
|
||||||
# Compute the logits and logprobs for all required tokens
|
# Compute the normalized logprobs for the requested tokens.
|
||||||
states = torch.cat(states, dim=0)
|
# Note that we pad a zero at the end for easy batching.
|
||||||
all_logits = self._get_logits(states, lm_head)
|
input_token_logprobs = input_logprobs[
|
||||||
if self.do_tensor_parallel_all_gather:
|
torch.arange(input_logprobs.shape[0], device="cuda"),
|
||||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
torch.cat(
|
||||||
|
[
|
||||||
|
torch.cat(pruned_input_ids)[1:],
|
||||||
|
torch.tensor([0], device="cuda"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||||
|
input_token_logprobs,
|
||||||
|
logits_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
# The LM head's weights may be zero-padded for parallelism. Remove any
|
return LogitsProcessorOutput(
|
||||||
# extra logits that this padding may have produced.
|
next_token_logits=last_logits,
|
||||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
||||||
|
input_token_logprobs=input_token_logprobs,
|
||||||
if self.final_logit_softcapping:
|
input_top_logprobs_val=input_top_logprobs_val,
|
||||||
all_logits.div_(self.final_logit_softcapping)
|
input_top_logprobs_idx=input_top_logprobs_idx,
|
||||||
torch.tanh(all_logits, out=all_logits)
|
)
|
||||||
all_logits.mul_(self.final_logit_softcapping)
|
|
||||||
|
|
||||||
all_logprobs = all_logits
|
|
||||||
del all_logits, hidden_states
|
|
||||||
|
|
||||||
all_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
|
||||||
all_logprobs, logits_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the logprob of top-k tokens
|
|
||||||
if logits_metadata.return_top_logprob:
|
|
||||||
(
|
|
||||||
input_top_logprobs_val,
|
|
||||||
input_top_logprobs_idx,
|
|
||||||
output_top_logprobs_val,
|
|
||||||
output_top_logprobs_idx,
|
|
||||||
) = self.get_top_logprobs(all_logprobs, logits_metadata)
|
|
||||||
else:
|
|
||||||
input_top_logprobs_val = input_top_logprobs_idx = (
|
|
||||||
output_top_logprobs_val
|
|
||||||
) = output_top_logprobs_idx = None
|
|
||||||
|
|
||||||
# Compute the normalized logprobs for the requested tokens.
|
|
||||||
# Note that we pad a zero at the end for easy batching.
|
|
||||||
input_token_logprobs = all_logprobs[
|
|
||||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
|
||||||
torch.cat(
|
|
||||||
[
|
|
||||||
torch.cat(pruned_input_ids)[1:],
|
|
||||||
torch.tensor([0], device="cuda"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
]
|
|
||||||
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
|
||||||
input_token_logprobs,
|
|
||||||
logits_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
return LogitsProcessorOutput(
|
|
||||||
next_token_logits=last_logits,
|
|
||||||
next_token_logprobs=last_logprobs,
|
|
||||||
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
|
||||||
input_token_logprobs=input_token_logprobs,
|
|
||||||
input_top_logprobs_val=input_top_logprobs_val,
|
|
||||||
input_top_logprobs_idx=input_top_logprobs_idx,
|
|
||||||
output_top_logprobs_val=output_top_logprobs_val,
|
|
||||||
output_top_logprobs_idx=output_top_logprobs_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_logits(
|
def _get_logits(
|
||||||
self,
|
self,
|
||||||
@@ -269,9 +228,19 @@ class LogitsProcessor(nn.Module):
|
|||||||
# GGUF models
|
# GGUF models
|
||||||
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
|
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
|
||||||
|
|
||||||
# Optional scaling factor
|
|
||||||
if self.logit_scale is not None:
|
if self.logit_scale is not None:
|
||||||
logits.mul_(self.logit_scale) # In-place multiply
|
logits.mul_(self.logit_scale)
|
||||||
|
|
||||||
|
if self.do_tensor_parallel_all_gather:
|
||||||
|
logits = tensor_model_parallel_all_gather(logits)
|
||||||
|
|
||||||
|
# Compute the normalized logprobs for the requested tokens.
|
||||||
|
# Note that we pad a zero at the end for easy batching.
|
||||||
|
logits = logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
|
if self.final_logit_softcapping:
|
||||||
|
fused_softcap(logits, self.final_logit_softcapping)
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -302,90 +271,73 @@ class LogitsProcessor(nn.Module):
|
|||||||
values = ret.values.tolist()
|
values = ret.values.tolist()
|
||||||
indices = ret.indices.tolist()
|
indices = ret.indices.tolist()
|
||||||
|
|
||||||
if logits_metadata.forward_mode.is_decode():
|
input_top_logprobs_val, input_top_logprobs_idx = [], []
|
||||||
output_top_logprobs_val = []
|
|
||||||
output_top_logprobs_idx = []
|
|
||||||
for i, k in enumerate(logits_metadata.top_logprobs_nums):
|
|
||||||
output_top_logprobs_val.append(values[i][:k])
|
|
||||||
output_top_logprobs_idx.append(indices[i][:k])
|
|
||||||
return None, None, output_top_logprobs_val, output_top_logprobs_idx
|
|
||||||
else:
|
|
||||||
input_top_logprobs_val, input_top_logprobs_idx = [], []
|
|
||||||
output_top_logprobs_val, output_top_logprobs_idx = [], []
|
|
||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
for k, pruned_len in zip(
|
for k, pruned_len in zip(
|
||||||
logits_metadata.top_logprobs_nums,
|
logits_metadata.top_logprobs_nums,
|
||||||
logits_metadata.extend_logprob_pruned_lens_cpu,
|
logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||||
):
|
):
|
||||||
if pruned_len <= 0:
|
if pruned_len <= 0:
|
||||||
input_top_logprobs_val.append([])
|
input_top_logprobs_val.append([])
|
||||||
input_top_logprobs_idx.append([])
|
input_top_logprobs_idx.append([])
|
||||||
output_top_logprobs_val.append([])
|
continue
|
||||||
output_top_logprobs_idx.append([])
|
|
||||||
continue
|
|
||||||
|
|
||||||
input_top_logprobs_val.append(
|
input_top_logprobs_val.append(
|
||||||
[values[pt + j][:k] for j in range(pruned_len - 1)]
|
[values[pt + j][:k] for j in range(pruned_len - 1)]
|
||||||
)
|
|
||||||
input_top_logprobs_idx.append(
|
|
||||||
[indices[pt + j][:k] for j in range(pruned_len - 1)]
|
|
||||||
)
|
|
||||||
output_top_logprobs_val.append(
|
|
||||||
list(
|
|
||||||
values[pt + pruned_len - 1][:k],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
output_top_logprobs_idx.append(
|
|
||||||
list(
|
|
||||||
indices[pt + pruned_len - 1][:k],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
pt += pruned_len
|
|
||||||
|
|
||||||
return (
|
|
||||||
input_top_logprobs_val,
|
|
||||||
input_top_logprobs_idx,
|
|
||||||
output_top_logprobs_val,
|
|
||||||
output_top_logprobs_idx,
|
|
||||||
)
|
)
|
||||||
|
input_top_logprobs_idx.append(
|
||||||
|
[indices[pt + j][:k] for j in range(pruned_len - 1)]
|
||||||
|
)
|
||||||
|
pt += pruned_len
|
||||||
|
|
||||||
|
return input_top_logprobs_val, input_top_logprobs_idx
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_temp_top_p_normalized_logprobs(
|
def compute_temp_top_p_normalized_logprobs(
|
||||||
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
|
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# TODO: Implement the temp and top-p normalization
|
||||||
return torch.nn.functional.log_softmax(last_logits, dim=-1)
|
return torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def test():
|
@triton.jit
|
||||||
all_logprobs = torch.tensor(
|
def fused_softcap_kernel(
|
||||||
# s s s
|
full_logits_ptr,
|
||||||
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
|
softcapping_value,
|
||||||
dtype=torch.float32,
|
n_elements,
|
||||||
device="cuda",
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
block_start = pid * BLOCK_SIZE
|
||||||
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = offsets < n_elements
|
||||||
|
|
||||||
|
# Load values
|
||||||
|
x = tl.load(full_logits_ptr + offsets, mask=mask)
|
||||||
|
|
||||||
|
# Perform operations in-place
|
||||||
|
x = x / softcapping_value
|
||||||
|
|
||||||
|
# Manual tanh implementation using exp
|
||||||
|
exp2x = tl.exp(2 * x)
|
||||||
|
x = (exp2x - 1) / (exp2x + 1)
|
||||||
|
|
||||||
|
x = x * softcapping_value
|
||||||
|
|
||||||
|
# Store result
|
||||||
|
tl.store(full_logits_ptr + offsets, x, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def fused_softcap(full_logits, final_logit_softcapping):
|
||||||
|
n_elements = full_logits.numel()
|
||||||
|
BLOCK_SIZE = 1024
|
||||||
|
grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)
|
||||||
|
|
||||||
|
fused_softcap_kernel[grid](
|
||||||
|
full_logits_ptr=full_logits,
|
||||||
|
softcapping_value=final_logit_softcapping,
|
||||||
|
n_elements=n_elements,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
)
|
)
|
||||||
seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
|
return full_logits
|
||||||
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
|
|
||||||
|
|
||||||
token_logprobs = all_logprobs[
|
|
||||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
|
||||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
|
||||||
]
|
|
||||||
logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
|
|
||||||
|
|
||||||
len_cumsum = torch.cumsum(seq_lens, dim=0)
|
|
||||||
start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
|
|
||||||
end = start + seq_lens - 2
|
|
||||||
start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
|
|
||||||
end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
|
|
||||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
|
|
||||||
|
|
||||||
# assert logprobs == [2, _, 2, 4, _]
|
|
||||||
print("token logprobs", token_logprobs)
|
|
||||||
print("start", start)
|
|
||||||
print("end", end)
|
|
||||||
print("sum_logp", sum_logp)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test()
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Union
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -28,13 +28,12 @@ class Sampler(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
logits: Union[torch.Tensor, LogitsProcessorOutput],
|
logits_output: LogitsProcessorOutput,
|
||||||
sampling_info: SamplingBatchInfo,
|
sampling_info: SamplingBatchInfo,
|
||||||
|
return_logprob: bool,
|
||||||
|
top_logprobs_nums: List[int],
|
||||||
):
|
):
|
||||||
if isinstance(logits, LogitsProcessorOutput):
|
logits = logits_output.next_token_logits
|
||||||
logits = logits.next_token_logits
|
|
||||||
|
|
||||||
logits = logits.contiguous()
|
|
||||||
|
|
||||||
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
||||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||||
@@ -47,6 +46,8 @@ class Sampler(nn.Module):
|
|||||||
if sampling_info.is_all_greedy:
|
if sampling_info.is_all_greedy:
|
||||||
# Use torch.argmax if all requests use greedy sampling
|
# Use torch.argmax if all requests use greedy sampling
|
||||||
batch_next_token_ids = torch.argmax(logits, -1)
|
batch_next_token_ids = torch.argmax(logits, -1)
|
||||||
|
if return_logprob:
|
||||||
|
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
else:
|
else:
|
||||||
# Post process logits
|
# Post process logits
|
||||||
logits.div_(sampling_info.temperatures)
|
logits.div_(sampling_info.temperatures)
|
||||||
@@ -54,6 +55,12 @@ class Sampler(nn.Module):
|
|||||||
del logits
|
del logits
|
||||||
|
|
||||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||||
|
if return_logprob:
|
||||||
|
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems
|
||||||
|
logprobs = torch.log(
|
||||||
|
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
||||||
|
)
|
||||||
|
|
||||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
uniform_samples = torch.rand(
|
uniform_samples = torch.rand(
|
||||||
(max_top_k_round, batch_size), device=probs.device
|
(max_top_k_round, batch_size), device=probs.device
|
||||||
@@ -76,6 +83,7 @@ class Sampler(nn.Module):
|
|||||||
if self.use_nan_detectioin and not torch.all(success):
|
if self.use_nan_detectioin and not torch.all(success):
|
||||||
logger.warning("Detected errors during sampling!")
|
logger.warning("Detected errors during sampling!")
|
||||||
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
||||||
|
|
||||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||||
# A slower fallback implementation with torch native operations.
|
# A slower fallback implementation with torch native operations.
|
||||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||||
@@ -85,12 +93,31 @@ class Sampler(nn.Module):
|
|||||||
sampling_info.min_ps,
|
sampling_info.min_ps,
|
||||||
sampling_info.need_min_p_sampling,
|
sampling_info.need_min_p_sampling,
|
||||||
)
|
)
|
||||||
|
if return_logprob:
|
||||||
|
logprobs = torch.log(
|
||||||
|
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return batch_next_token_ids.to(torch.int32)
|
batch_next_token_ids = batch_next_token_ids.to(torch.int32)
|
||||||
|
|
||||||
|
# Attach logprobs to logits_output (in-place modification)
|
||||||
|
if return_logprob:
|
||||||
|
if any(x > 0 for x in top_logprobs_nums):
|
||||||
|
(
|
||||||
|
logits_output.next_token_top_logprobs_val,
|
||||||
|
logits_output.next_token_top_logprobs_idx,
|
||||||
|
) = get_top_logprobs(logprobs, top_logprobs_nums)
|
||||||
|
|
||||||
|
logits_output.next_token_logprobs = logprobs[
|
||||||
|
torch.arange(len(batch_next_token_ids), device=sampling_info.device),
|
||||||
|
batch_next_token_ids,
|
||||||
|
]
|
||||||
|
|
||||||
|
return batch_next_token_ids
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_min_p_sampling_from_probs_torch(
|
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||||
@@ -120,20 +147,27 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
|
|||||||
return batch_next_token_ids
|
return batch_next_token_ids
|
||||||
|
|
||||||
|
|
||||||
def top_p_normalize_probs(
|
def top_p_normalize_probs_torch(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
top_ps: torch.Tensor,
|
top_ps: torch.Tensor,
|
||||||
):
|
):
|
||||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
# See also top_k_top_p_min_p_sampling_from_probs_torch
|
||||||
return top_p_renorm_prob(probs, top_ps)
|
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
# See also top_k_top_p_min_p_sampling_from_probs_torch
|
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
||||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
|
||||||
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
|
||||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
|
||||||
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
|
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
|
||||||
else:
|
max_k = max(top_logprobs_nums)
|
||||||
raise ValueError(
|
ret = logprobs.topk(max_k, dim=1)
|
||||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
values = ret.values.tolist()
|
||||||
)
|
indices = ret.indices.tolist()
|
||||||
|
|
||||||
|
output_top_logprobs_val = []
|
||||||
|
output_top_logprobs_idx = []
|
||||||
|
for i, k in enumerate(top_logprobs_nums):
|
||||||
|
output_top_logprobs_val.append(values[i][:k])
|
||||||
|
output_top_logprobs_idx.append(indices[i][:k])
|
||||||
|
return output_top_logprobs_val, output_top_logprobs_idx
|
||||||
|
|||||||
@@ -974,12 +974,10 @@ class Scheduler:
|
|||||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||||
else:
|
else:
|
||||||
# Move next_token_ids and logprobs to cpu
|
# Move next_token_ids and logprobs to cpu
|
||||||
|
next_token_ids = next_token_ids.tolist()
|
||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
logits_output.next_token_logprobs = (
|
logits_output.next_token_logprobs = (
|
||||||
logits_output.next_token_logprobs[
|
logits_output.next_token_logprobs.tolist()
|
||||||
torch.arange(len(next_token_ids), device=self.device),
|
|
||||||
next_token_ids,
|
|
||||||
].tolist()
|
|
||||||
)
|
)
|
||||||
logits_output.input_token_logprobs = (
|
logits_output.input_token_logprobs = (
|
||||||
logits_output.input_token_logprobs.tolist()
|
logits_output.input_token_logprobs.tolist()
|
||||||
@@ -987,7 +985,6 @@ class Scheduler:
|
|||||||
logits_output.normalized_prompt_logprobs = (
|
logits_output.normalized_prompt_logprobs = (
|
||||||
logits_output.normalized_prompt_logprobs.tolist()
|
logits_output.normalized_prompt_logprobs.tolist()
|
||||||
)
|
)
|
||||||
next_token_ids = next_token_ids.tolist()
|
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
logprob_pt = 0
|
logprob_pt = 0
|
||||||
@@ -1064,13 +1061,9 @@ class Scheduler:
|
|||||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||||
next_token_logprobs = logits_output.next_token_logprobs
|
next_token_logprobs = logits_output.next_token_logprobs
|
||||||
else:
|
else:
|
||||||
# Move next_token_ids and logprobs to cpu
|
|
||||||
if batch.return_logprob:
|
|
||||||
next_token_logprobs = logits_output.next_token_logprobs[
|
|
||||||
torch.arange(len(next_token_ids), device=self.device),
|
|
||||||
next_token_ids,
|
|
||||||
].tolist()
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
|
if batch.return_logprob:
|
||||||
|
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||||
|
|
||||||
self.token_to_kv_pool.free_group_begin()
|
self.token_to_kv_pool.free_group_begin()
|
||||||
|
|
||||||
@@ -1095,10 +1088,10 @@ class Scheduler:
|
|||||||
req.output_token_logprobs_idx.append(next_token_id)
|
req.output_token_logprobs_idx.append(next_token_id)
|
||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
req.output_top_logprobs_val.append(
|
req.output_top_logprobs_val.append(
|
||||||
logits_output.output_top_logprobs_val[i]
|
logits_output.next_token_top_logprobs_val[i]
|
||||||
)
|
)
|
||||||
req.output_top_logprobs_idx.append(
|
req.output_top_logprobs_idx.append(
|
||||||
logits_output.output_top_logprobs_idx[i]
|
logits_output.next_token_top_logprobs_idx[i]
|
||||||
)
|
)
|
||||||
|
|
||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
@@ -1200,8 +1193,9 @@ class Scheduler:
|
|||||||
req.output_top_logprobs_idx.extend(
|
req.output_top_logprobs_idx.extend(
|
||||||
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
|
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
|
||||||
)
|
)
|
||||||
req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
|
|
||||||
req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
|
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
|
||||||
|
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
|
||||||
|
|
||||||
return num_input_logprobs
|
return num_input_logprobs
|
||||||
|
|
||||||
|
|||||||
@@ -144,10 +144,9 @@ class TpModelWorkerClient:
|
|||||||
|
|
||||||
# Copy results to the CPU
|
# Copy results to the CPU
|
||||||
if model_worker_batch.return_logprob:
|
if model_worker_batch.return_logprob:
|
||||||
logits_output.next_token_logprobs = logits_output.next_token_logprobs[
|
logits_output.next_token_logprobs = (
|
||||||
torch.arange(len(next_token_ids), device=self.device),
|
logits_output.next_token_logprobs.to("cpu", non_blocking=True)
|
||||||
next_token_ids,
|
)
|
||||||
].to("cpu", non_blocking=True)
|
|
||||||
if logits_output.input_token_logprobs is not None:
|
if logits_output.input_token_logprobs is not None:
|
||||||
logits_output.input_token_logprobs = (
|
logits_output.input_token_logprobs = (
|
||||||
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
|
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
|
||||||
|
|||||||
@@ -392,34 +392,7 @@ class CudaGraphRunner:
|
|||||||
self.graphs[bs].replay()
|
self.graphs[bs].replay()
|
||||||
next_token_logits = self.output_buffers[bs][:raw_bs]
|
next_token_logits = self.output_buffers[bs][:raw_bs]
|
||||||
|
|
||||||
# Extract logprobs
|
logits_output = LogitsProcessorOutput(
|
||||||
if forward_batch.return_logprob:
|
next_token_logits=next_token_logits,
|
||||||
logits_metadata = LogitsMetadata(
|
)
|
||||||
forward_mode=ForwardMode.DECODE,
|
|
||||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
|
||||||
)
|
|
||||||
next_token_logprobs = (
|
|
||||||
LogitsProcessor.compute_temp_top_p_normalized_logprobs(
|
|
||||||
next_token_logits, logits_metadata
|
|
||||||
)
|
|
||||||
)
|
|
||||||
logits_output = LogitsProcessorOutput(
|
|
||||||
next_token_logits=next_token_logits,
|
|
||||||
next_token_logprobs=next_token_logprobs,
|
|
||||||
)
|
|
||||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
|
||||||
if return_top_logprob:
|
|
||||||
(
|
|
||||||
logits_output.output_top_logprobs_val,
|
|
||||||
logits_output.output_top_logprobs_idx,
|
|
||||||
) = LogitsProcessor.get_top_logprobs(
|
|
||||||
next_token_logprobs, logits_metadata
|
|
||||||
)[
|
|
||||||
2:4
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
logits_output = LogitsProcessorOutput(
|
|
||||||
next_token_logits=next_token_logits,
|
|
||||||
)
|
|
||||||
|
|
||||||
return logits_output
|
return logits_output
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
|
|||||||
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import Sampler
|
from sglang.srt.layers.sampler import Sampler, get_top_logprobs
|
||||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
|
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
|
||||||
from sglang.srt.lora.lora_manager import LoRAManager
|
from sglang.srt.lora.lora_manager import LoRAManager
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
@@ -48,7 +48,6 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
@@ -192,7 +191,8 @@ class ModelRunner:
|
|||||||
torch.get_device_module(self.device).set_device(self.gpu_id)
|
torch.get_device_module(self.device).set_device(self.gpu_id)
|
||||||
if self.device == "cuda":
|
if self.device == "cuda":
|
||||||
backend = "nccl"
|
backend = "nccl"
|
||||||
# ToDO(liangan1):Just use gloo to bypass the initilization fail
|
|
||||||
|
# TODO(liangan1):Just use gloo to bypass the initilization fail
|
||||||
# Need to use xccl for xpu backend in the future
|
# Need to use xccl for xpu backend in the future
|
||||||
elif self.device == "xpu":
|
elif self.device == "xpu":
|
||||||
backend = "gloo"
|
backend = "gloo"
|
||||||
@@ -704,6 +704,7 @@ class ModelRunner:
|
|||||||
def sample(
|
def sample(
|
||||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# Apply logit bias
|
||||||
sampling_info = forward_batch.sampling_info
|
sampling_info = forward_batch.sampling_info
|
||||||
if sampling_info.sampling_info_done:
|
if sampling_info.sampling_info_done:
|
||||||
# Overlap mode: the function update_regex_vocab_mask was executed
|
# Overlap mode: the function update_regex_vocab_mask was executed
|
||||||
@@ -714,35 +715,17 @@ class ModelRunner:
|
|||||||
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||||
sampling_info.update_regex_vocab_mask()
|
sampling_info.update_regex_vocab_mask()
|
||||||
sampling_info.update_penalties()
|
sampling_info.update_penalties()
|
||||||
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
|
sampling_info.apply_logits_bias(logits_output.next_token_logits)
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Sample the next tokens
|
||||||
next_token_ids = self.sampler(logits, sampling_info)
|
next_token_ids = self.sampler(
|
||||||
|
logits_output,
|
||||||
|
sampling_info,
|
||||||
|
forward_batch.return_logprob,
|
||||||
|
forward_batch.top_logprobs_nums,
|
||||||
|
)
|
||||||
return next_token_ids
|
return next_token_ids
|
||||||
|
|
||||||
def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
|
||||||
# Apply logit_bias
|
|
||||||
if sampling_info.logit_bias is not None:
|
|
||||||
logits.add_(sampling_info.logit_bias)
|
|
||||||
|
|
||||||
# min-token, presence, frequency
|
|
||||||
if sampling_info.linear_penalties is not None:
|
|
||||||
logits.add_(sampling_info.linear_penalties)
|
|
||||||
|
|
||||||
# repetition
|
|
||||||
if sampling_info.scaling_penalties is not None:
|
|
||||||
logits = torch.where(
|
|
||||||
logits > 0,
|
|
||||||
logits / sampling_info.scaling_penalties,
|
|
||||||
logits * sampling_info.scaling_penalties,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply regex vocab_mask
|
|
||||||
if sampling_info.vocab_mask is not None:
|
|
||||||
sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
|
|
||||||
|
|
||||||
return logits
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_is_mrope(self) -> bool:
|
def model_is_mrope(self) -> bool:
|
||||||
"""Detect if the model has "mrope" rope_scaling type.
|
"""Detect if the model has "mrope" rope_scaling type.
|
||||||
|
|||||||
@@ -232,3 +232,26 @@ class SamplingBatchInfo:
|
|||||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def apply_logits_bias(self, logits: torch.Tensor):
|
||||||
|
# Apply logit_bias
|
||||||
|
if self.logit_bias is not None:
|
||||||
|
logits.add_(self.logit_bias)
|
||||||
|
|
||||||
|
# min-token, presence, frequency
|
||||||
|
if self.linear_penalties is not None:
|
||||||
|
logits.add_(self.linear_penalties)
|
||||||
|
|
||||||
|
# repetition
|
||||||
|
if self.scaling_penalties is not None:
|
||||||
|
logits = torch.where(
|
||||||
|
logits > 0,
|
||||||
|
logits / self.scaling_penalties,
|
||||||
|
logits * self.scaling_penalties,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply regex vocab_mask
|
||||||
|
if self.vocab_mask is not None:
|
||||||
|
self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import requests
|
|||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
@@ -17,7 +17,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
cls.process = popen_launch_server(
|
cls.process = popen_launch_server(
|
||||||
cls.model,
|
cls.model,
|
||||||
|
|||||||
@@ -213,6 +213,41 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
max_diff = np.max(diff)
|
max_diff = np.max(diff)
|
||||||
self.assertLess(max_diff, 0.25)
|
self.assertLess(max_diff, 0.25)
|
||||||
|
|
||||||
|
def test_logprob_grammar(self):
|
||||||
|
prompts = "Question: Is Paris the Capital of France? Answer:"
|
||||||
|
allowed_tokens = [" Yes", " No"]
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": prompts,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 1.0,
|
||||||
|
"max_new_tokens": 1,
|
||||||
|
"regex": "( Yes| No)",
|
||||||
|
},
|
||||||
|
"return_logprob": True,
|
||||||
|
"top_logprobs_num": 5,
|
||||||
|
"return_text_in_logprobs": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
output_top_logprobs = response_json["meta_info"]["output_top_logprobs"][0]
|
||||||
|
print(f"{output_top_logprobs=}")
|
||||||
|
|
||||||
|
# Parse results
|
||||||
|
# This is becaues the grammar constraint allows all prefix tokens
|
||||||
|
logprobs = [None] * 2
|
||||||
|
for i in range(len(output_top_logprobs)):
|
||||||
|
try:
|
||||||
|
idx = allowed_tokens.index(output_top_logprobs[i][2])
|
||||||
|
except ValueError:
|
||||||
|
# Not found
|
||||||
|
continue
|
||||||
|
logprobs[idx] = output_top_logprobs[i][0]
|
||||||
|
|
||||||
|
self.assertTrue(all(x is not None for x in logprobs))
|
||||||
|
|
||||||
def test_get_server_info(self):
|
def test_get_server_info(self):
|
||||||
response = requests.get(self.base_url + "/get_server_info")
|
response = requests.get(self.base_url + "/get_server_info")
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|||||||
Reference in New Issue
Block a user