Move sampler into CUDA graph (#1201)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from flashinfer.sampling import (
|
||||
@@ -9,6 +11,8 @@ from flashinfer.sampling import (
|
||||
)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
|
||||
# TODO: move this dict to another place
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
@@ -16,30 +20,71 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SampleOutput:
|
||||
success: torch.Tensor
|
||||
probs: torch.Tensor
|
||||
batch_next_token_ids: torch.Tensor
|
||||
|
||||
|
||||
class Sampler(CustomOp):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
||||
def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
||||
# min-token, presence, frequency
|
||||
if sampling_info.linear_penalties is not None:
|
||||
logits += sampling_info.linear_penalties
|
||||
|
||||
# repetition
|
||||
if sampling_info.scaling_penalties is not None:
|
||||
logits = torch.where(
|
||||
logits > 0,
|
||||
logits / sampling_info.scaling_penalties,
|
||||
logits * sampling_info.scaling_penalties,
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
def _get_probs(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_info: SamplingBatchInfo,
|
||||
is_torch_compile: bool = False,
|
||||
):
|
||||
# Post process logits
|
||||
logits = logits.contiguous()
|
||||
logits.div_(sampling_info.temperatures)
|
||||
if is_torch_compile:
|
||||
# FIXME: Temporary workaround for unknown bugs in torch.compile
|
||||
logits.add_(0)
|
||||
|
||||
if sampling_info.logit_bias is not None:
|
||||
logits.add_(sampling_info.logit_bias)
|
||||
|
||||
if sampling_info.vocab_mask is not None:
|
||||
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
|
||||
|
||||
logits = sampling_info.penalizer_orchestrator.apply(logits)
|
||||
logits = self._apply_penalties(logits, sampling_info)
|
||||
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
return torch.softmax(logits, dim=-1)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
logits: Union[torch.Tensor, LogitsProcessorOutput],
|
||||
sampling_info: SamplingBatchInfo,
|
||||
):
|
||||
if isinstance(logits, LogitsProcessorOutput):
|
||||
logits = logits.next_token_logits
|
||||
|
||||
probs = self._get_probs(logits, sampling_info)
|
||||
|
||||
if not global_server_args_dict["disable_flashinfer_sampling"]:
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand(
|
||||
(max_top_k_round, batch_size), device=probs.device
|
||||
)
|
||||
if sampling_info.min_ps.any():
|
||||
if sampling_info.need_min_p_sampling:
|
||||
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
||||
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
||||
batch_next_token_ids, success = min_p_sampling_from_probs(
|
||||
@@ -55,18 +100,23 @@ class Sampler(CustomOp):
|
||||
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
|
||||
)
|
||||
|
||||
if not torch.all(success):
|
||||
logging.warning("Sampling failed, fallback to top_k=1 strategy")
|
||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||
argmax_ids = torch.argmax(probs, dim=-1)
|
||||
batch_next_token_ids = torch.where(
|
||||
success, batch_next_token_ids, argmax_ids
|
||||
)
|
||||
return SampleOutput(success, probs, batch_next_token_ids)
|
||||
|
||||
return batch_next_token_ids
|
||||
def forward_native(
|
||||
self,
|
||||
logits: Union[torch.Tensor, LogitsProcessorOutput],
|
||||
sampling_info: SamplingBatchInfo,
|
||||
):
|
||||
if isinstance(logits, LogitsProcessorOutput):
|
||||
logits = logits.next_token_logits
|
||||
|
||||
def forward_native():
|
||||
raise NotImplementedError("Native forward is not implemented yet.")
|
||||
probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
|
||||
|
||||
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
|
||||
)
|
||||
|
||||
return SampleOutput(success, probs, batch_next_token_ids)
|
||||
|
||||
|
||||
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
||||
try:
|
||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
||||
# FIXME: torch.multiomial does not support num_samples = 1
|
||||
sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
|
||||
:, :1
|
||||
]
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Sampling error: {e}")
|
||||
batch_next_token_ids = torch.zeros(
|
||||
|
||||
Reference in New Issue
Block a user