Fix bugs in sampler with CUDA graph / torch.compile (#1306)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Union
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
from flashinfer.sampling import (
|
||||
@@ -9,6 +9,7 @@ from flashinfer.sampling import (
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
)
|
||||
from torch.library import custom_op as torch_custom_op
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
@@ -30,6 +31,9 @@ class SampleOutput:
|
||||
class Sampler(CustomOp):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# FIXME: torch.multinomial has too many bugs
|
||||
self.forward_native = self.forward_cuda
|
||||
self.is_torch_compile = False
|
||||
|
||||
def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
||||
# min-token, presence, frequency
|
||||
@@ -46,16 +50,11 @@ class Sampler(CustomOp):
|
||||
|
||||
return logits
|
||||
|
||||
def _get_probs(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_info: SamplingBatchInfo,
|
||||
is_torch_compile: bool = False,
|
||||
):
|
||||
def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
||||
# Post process logits
|
||||
logits = logits.contiguous()
|
||||
logits.div_(sampling_info.temperatures)
|
||||
if is_torch_compile:
|
||||
if self.is_torch_compile:
|
||||
# FIXME: Temporary workaround for unknown bugs in torch.compile
|
||||
logits.add_(0)
|
||||
|
||||
@@ -91,7 +90,7 @@ class Sampler(CustomOp):
|
||||
probs, uniform_samples, sampling_info.min_ps
|
||||
)
|
||||
else:
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
batch_next_token_ids, success = flashinfer_top_k_top_p(
|
||||
probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
|
||||
)
|
||||
else:
|
||||
@@ -110,7 +109,7 @@ class Sampler(CustomOp):
|
||||
if isinstance(logits, LogitsProcessorOutput):
|
||||
logits = logits.next_token_logits
|
||||
|
||||
probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
|
||||
probs = self._get_probs(logits, sampling_info)
|
||||
|
||||
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
|
||||
@@ -119,6 +118,31 @@ class Sampler(CustomOp):
|
||||
return SampleOutput(success, probs, batch_next_token_ids)
|
||||
|
||||
|
||||
@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
|
||||
def flashinfer_top_k_top_p(
|
||||
probs: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
top_ks: torch.Tensor,
|
||||
top_ps: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# NOTE: we do not use min_p neither in CUDA nor in torch.compile
|
||||
return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
|
||||
|
||||
|
||||
@flashinfer_top_k_top_p.register_fake
|
||||
def _(
|
||||
probs: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
top_ks: torch.Tensor,
|
||||
top_ps: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
bs = probs.shape[0]
|
||||
return (
|
||||
torch.ones(bs, dtype=torch.bool, device=probs.device),
|
||||
torch.zeros(bs, dtype=torch.int32, device=probs.device),
|
||||
)
|
||||
|
||||
|
||||
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs: torch.Tensor,
|
||||
top_ks: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user