[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
406
vllm/model_executor/layers/rejection_sampler.py
Normal file
406
vllm/model_executor/layers/rejection_sampler.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import cached_property
|
||||
from importlib.util import find_spec
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeStochasticBaseSampler)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Prefer the FlashInfer rejection-sampling kernel when the package is
# installed; otherwise leave a None sentinel that the sampler checks at
# construction time to fall back to the pure-PyTorch path.
if find_spec("flashinfer"):
    """
    Consider utilizing the FlashInfer rejection sampling kernel initially,
    as it employs a dedicated kernel rather than relying on
    Torch tensor operations. This design choice helps to fuse operations,
    reduce memory I/O, and consequently enhances performance.
    """
    from flashinfer.sampling import chain_speculative_sampling
else:
    # Sentinel: RejectionSampler.__init__ treats None as "kernel unavailable".
    chain_speculative_sampling = None
|
||||
|
||||
|
||||
class RejectionSampler(SpecDecodeStochasticBaseSampler):
    """Apply modified rejection sampling as described in "Accelerating Large
    Language Model Decoding with Speculative Sampling"
    https://arxiv.org/pdf/2302.01318.pdf.
    """

    def __init__(self,
                 strict_mode: bool = False,
                 use_flashinfer: Optional[bool] = None):
        """Create a rejection sampler.

        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
            use_flashinfer: We will use this parameter to determine whether
                to use the FlashInfer rejection sampling kernel or not. If
                it's None, we will use the default value from the environment
                variable. This parameter is only used for testing purposes.
        """
        super().__init__(strict_mode=strict_mode)
        if use_flashinfer is None:
            # Auto-detect: require both the env-var opt-in and a successful
            # flashinfer import (chain_speculative_sampling is None otherwise).
            self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
                chain_speculative_sampling is not None)
        else:
            self.use_flashinfer = use_flashinfer

        if self.use_flashinfer:
            logger.info("Use flashinfer for rejection sampling.")
        else:
            logger.info("Use pytorch for rejection sampling.")

    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        seeded_seqs: Optional[dict[int, torch.Generator]] = None,
    ) -> torch.Tensor:
        """Sample token ids using rejection sampling. This accepts or rejects
        tokens proposed by the draft model using the probability of each token
        according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is
        guaranteed one correct token will be emitted.

        In the case where all draft tokens are accepted, a bonus token will be
        accepted as its cheap to have the target model score this speculative
        sequence.

        Args:
            target_with_bonus_probs: The probability distribution
                over token ids given context according to the target model.
                shape = [batch_size, num_speculative_tokens + 1, vocab_size]

            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]

            draft_probs: The probability distribution over token ids given
                context according to the draft model.
                shape = [batch_size, num_speculative_tokens, vocab_size]

            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

            seeded_seqs: Dict of batch row index to torch generator, for
                sequences using seeded generation.

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_with_bonus_probs,
                                           draft_token_ids, bonus_token_ids,
                                           draft_probs)

        batch_size, k, _ = draft_probs.shape

        # batch_size = 0 when all requests in the batch are
        # non_spec requests. In this case, output_token_ids is
        # just an empty tensor.
        if batch_size == 0:
            return torch.empty(0, k + 1, device=draft_probs.device, dtype=int)

        # If use Flashinfer chain_speculative_sampling kernel
        # for rejection sampling
        if self.use_flashinfer and chain_speculative_sampling is not None:
            batch_size, k, _ = draft_probs.shape

            # NOTE(review): the flashinfer path does not thread seeded_seqs
            # through, so seeded generation presumably falls back to the
            # kernel's own RNG -- confirm against callers that seed.
            (output_token_ids, accepted_token_num,
             emitted_token_num) = chain_speculative_sampling(
                 draft_probs,
                 draft_token_ids,
                 target_with_bonus_probs,
             )

            # num_emitted_tokens returned by flashinfer
            # does not include the bonus token
            # Flashinfer stops at the first token that violates
            # the condition p >= q and does not include recovery/bonus token.
            # Therefore, we need to add batch_size here.
            self.num_accepted_tokens += accepted_token_num.sum()
            self.num_emitted_tokens += emitted_token_num.sum() + batch_size
            self.num_draft_tokens += batch_size * k
        else:
            # PyTorch fallback: accept/reject per draft token, then sample
            # replacements from the recovered distribution for rejected slots.
            accepted, recovered_token_ids = (
                self._batch_modified_rejection_sampling(
                    # Drop the trailing bonus position; only the first k
                    # positions align with draft_probs.
                    target_with_bonus_probs[:, :-1],
                    draft_probs,
                    draft_token_ids,
                    seeded_seqs,
                ))

            output_token_ids = self._create_output(
                accepted,
                recovered_token_ids,
                draft_token_ids,
                bonus_token_ids,
            )

        return output_token_ids

    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        seeded_seqs: Optional[dict[int, torch.Generator]],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor of which tokens in each sequence is accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """

        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids, seeded_seqs)

        # Flatten to [batch_size * k, vocab_size] so _multinomial can draw
        # one sample per (sequence, position) row.
        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        # NOTE: the recovered_probs are overwritten by this method.
        recovered_token_ids = _multinomial(
            recovered_probs,
            num_samples=1,
            k=k,
            seeded_seqs=seeded_seqs or {},
        ).reshape(batch_size, k)

        return accepted, recovered_token_ids

    def _create_uniform_samples(self,
                                seeded_seqs: Optional[dict[int,
                                                           torch.Generator]],
                                batch_size: int, k: int,
                                device: torch.device) -> torch.Tensor:
        """
        Generates a batch of uniform random samples, with optional seeding
        for specific sequences.

        This method creates a tensor of shape `(batch_size, k + 1)` filled
        with uniform random values in the range [0, 1). If `seeded_seqs`
        is provided, the sequences corresponding to specific indices
        will be generated using the provided `torch.Generator` for
        reproducibility. The other sequences will be generated without
        a seed.

        Args:
            seeded_seqs : Optional[dict[int, torch.Generator]]
                A dictionary mapping indices in the batch to
                `torch.Generator` objects. If `None`, all samples are
                generated without a seed.
            batch_size : int
                The number of sequences to generate.
            k : int
                The number of random samples per sequence.
            device : torch.device
                The device on which to allocate the tensor.

        Returns:
            uniform_rand : torch.Tensor
                A tensor of shape `(batch_size, k + 1)` containing uniform
                random values in the range [0, 1).
        """
        if not seeded_seqs:
            # NOTE(review): this fast path uses torch's default float dtype,
            # while the seeded paths below pass dtype=self.probs_dtype --
            # confirm these agree when probs_dtype is not float32.
            return torch.rand(batch_size, k + 1, device=device)

        uniform_rand = torch.empty(batch_size, k + 1, device=device)

        # Fill seeded rows one at a time (each needs its own generator);
        # batch the remaining unseeded rows into a single rand() call.
        non_seeded_indices = []
        for idx in range(batch_size):
            generator = seeded_seqs.get(idx)
            if generator is None:
                non_seeded_indices.append(idx)
            else:
                uniform_rand[idx, :] = torch.rand(1,
                                                  k + 1,
                                                  dtype=self.probs_dtype,
                                                  device=device,
                                                  generator=generator)
        if non_seeded_indices:
            uniform_rand[non_seeded_indices, :] = torch.rand(
                len(non_seeded_indices),
                k + 1,
                dtype=self.probs_dtype,
                device=device)
        return uniform_rand

    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        seeded_seqs: Optional[dict[int, torch.Generator]],
    ) -> torch.Tensor:
        r"""Create bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of
        $\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according
        to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        $$
        \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                        {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
        $$

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        # Advanced-indexing helpers: batch_indices broadcasts against
        # probs_indicies to gather one probability per (row, position).
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indicies = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indicies,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indicies,
                                             draft_token_ids]

        # _create_uniform_samples allocates (batch_size, k + 1) columns, so
        # pass k - 1 to obtain exactly k uniforms per row.
        uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
                                                    k - 1, target_probs.device)

        # Accept with probability min(1, q/p); comparing against a uniform
        # in [0, 1) realizes that Bernoulli draw.
        capped_ratio = torch.minimum(
            selected_target_probs / selected_draft_probs,
            torch.full((1, ), 1, device=target_probs.device))
        accepted = uniform_rand < capped_ratio

        return accepted

    def _get_recovered_probs(
        self,
        target_probs: torch.Tensor,  # [k, vocab_size]
        draft_probs: torch.Tensor,  # [k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of
        the target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is
        constructed as follows. Given $q(x|x_1, \dots, x_n)$, the probability
        of $x$ given context $x_1, \dots, x_n$ according to the target
        model and $p(x|x_1, \dots, x_n)$, the same conditional probability
        according to the draft model:

        $$
        x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
        $$

        where $(f(x))_+$ is defined as:

        $$
        (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
        $$

        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
        of the draft, target, and recovered probability distributions.

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note:
            This batches operations on GPU and thus constructs the recovered
            distribution for all tokens, even if they are accepted. This
            causes division-by-zero errors, so we use
            self._smallest_positive_value to avoid that. This introduces some
            drift to the distribution.
        """
        _, k, _ = draft_probs.shape

        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO(cade): Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # shape [batch_size, k, vocab_size]
        # Clamp to a tiny positive floor so the normalizing sum below can
        # never be zero (see the Note in the docstring).
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

        return recovered_probs

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs
        dtype. This value is used when constructing a distribution from which
        to sample recovered tokens in the first rejection case.

        See _get_recovered_probs for more details

        Note that this isn't actually the smallest positive value
        representable by float32, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more
        information.
        """
        return torch.finfo(self.probs_dtype).tiny
|
||||
|
||||
|
||||
# torch.multinomial forces a GPU<->CPU sync.
|
||||
# Therefore, we use an optimized implementation instead that skips the sync.
|
||||
# Note that we always sample with replacement.
|
||||
# probs will be modified in place, but this is fine, as we pass
|
||||
# in a copy already.
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def _multinomial(
|
||||
probs: torch.Tensor,
|
||||
num_samples: int,
|
||||
k: int,
|
||||
seeded_seqs: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
|
||||
if num_samples > 1:
|
||||
# This is equivalent to torch.repeat_interleaved (which also
|
||||
# forces a GPU<->CPU sync).
|
||||
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
|
||||
probs.shape[1]).contiguous().view(
|
||||
-1, probs.shape[1])
|
||||
q = torch.empty_like(probs)
|
||||
if not seeded_seqs:
|
||||
q.exponential_(1.0)
|
||||
else:
|
||||
start = 0
|
||||
for idx in range(len(q) // k):
|
||||
end = start + k
|
||||
generator = seeded_seqs.get(idx)
|
||||
# Note: generator might be None for non seeded
|
||||
q[start:end].exponential_(1.0, generator=generator)
|
||||
start = end
|
||||
|
||||
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
||||
Reference in New Issue
Block a user