init src 0.9.2
This commit is contained in:
0
vllm/v1/sample/__init__.py
Normal file
0
vllm/v1/sample/__init__.py
Normal file
517
vllm/v1/sample/logits_processor.py
Normal file
517
vllm/v1/sample/logits_processor.py
Normal file
@@ -0,0 +1,517 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._prims_common import DeviceLikeType
|
||||
|
||||
from vllm import PoolingParams, SamplingParams
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MoveDirectionality(Enum):
|
||||
# One-way i1->i2 req move within batch
|
||||
UNIDIRECTIONAL = 0
|
||||
# Two-way i1<->i2 req swap within batch
|
||||
SWAP = 1
|
||||
|
||||
|
||||
# (index, params, output_tok_ids) tuples for new
|
||||
# requests added to the batch.
|
||||
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
|
||||
# (index 1, index 2, directionality) tuples representing
|
||||
# one-way moves or two-way swaps of requests in batch
|
||||
MovedRequest = tuple[int, int, MoveDirectionality]
|
||||
# Batch indices of any removed requests.
|
||||
RemovedRequest = int
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BatchUpdate:
|
||||
"""Persistent batch state change info for logitsprocs"""
|
||||
batch_size: int # Current num reqs in batch
|
||||
|
||||
# Metadata for requests added to, removed from, and moved
|
||||
# within the persistent batch.
|
||||
#
|
||||
# Note: each added request is represented as
|
||||
# (index, params, output_tok_ids)
|
||||
# Key assumption: output_tok_ids is a reference to the
|
||||
# request's running output tokens list; in this way
|
||||
# the logits processors always see the latest list of
|
||||
# generated tokens
|
||||
removed: Sequence[RemovedRequest]
|
||||
moved: Sequence[MovedRequest]
|
||||
added: Sequence[AddedRequest]
|
||||
|
||||
|
||||
class BatchUpdateBuilder:
|
||||
"""Helps track persistent batch state changes and build
|
||||
a batch update data structure for logitsprocs
|
||||
|
||||
Assumptions:
|
||||
* All information about requests removed from persistent batch
|
||||
during a step is aggregated in self._removed through calls to
|
||||
self.removed_append() at the beginning of a step. This must happen
|
||||
before the first time that self.removed, self.pop_removed()
|
||||
or self.peek_removed() are invoked in a given step
|
||||
* After the first time that self.removed, self.pop_removed()
|
||||
or self.peek_removed() are read in a step, no new removals
|
||||
are registered using self.removed_append()
|
||||
* Elements of self._removed are never directly modified, added or
|
||||
removed (i.e. modification is only via self.removed_append() and
|
||||
self.pop_removed())
|
||||
|
||||
Guarantees under above assumptions:
|
||||
* self.removed is always sorted in descending order
|
||||
* self.pop_removed() and self.peek_removed() both return
|
||||
the lowest removed request index in the current step
|
||||
"""
|
||||
|
||||
_removed: list[RemovedRequest]
|
||||
_is_removed_sorted: bool
|
||||
moved: list[MovedRequest]
|
||||
added: list[AddedRequest]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
removed: Optional[list[RemovedRequest]] = None,
|
||||
moved: Optional[list[MovedRequest]] = None,
|
||||
added: Optional[list[AddedRequest]] = None,
|
||||
) -> None:
|
||||
self._removed = removed or []
|
||||
self.moved = moved or []
|
||||
self.added = added or []
|
||||
self._is_removed_sorted = False
|
||||
|
||||
def _ensure_removed_sorted(self) -> None:
|
||||
"""Sort removed request indices in
|
||||
descending order.
|
||||
|
||||
Idempotent after first call in a
|
||||
given step, until reset.
|
||||
"""
|
||||
if not self._is_removed_sorted:
|
||||
self._removed.sort(reverse=True)
|
||||
self._is_removed_sorted = True
|
||||
|
||||
@property
|
||||
def removed(self) -> list[RemovedRequest]:
|
||||
"""Removed request indices sorted in
|
||||
descending order"""
|
||||
self._ensure_removed_sorted()
|
||||
return self._removed
|
||||
|
||||
def removed_append(self, index: int) -> None:
|
||||
"""Register the removal of a request from
|
||||
the persistent batch.
|
||||
|
||||
Must not be called after the first time
|
||||
self.removed, self.pop_removed() or
|
||||
self.peek_removed() are invoked.
|
||||
|
||||
Args:
|
||||
index: request index
|
||||
"""
|
||||
if self._is_removed_sorted:
|
||||
raise RuntimeError("Cannot register new removed request after"
|
||||
" self.removed has been read.")
|
||||
self._removed.append(index)
|
||||
|
||||
def has_removed(self) -> bool:
|
||||
return bool(self._removed)
|
||||
|
||||
def peek_removed(self) -> Optional[int]:
|
||||
"""Return lowest removed request index"""
|
||||
if self.has_removed():
|
||||
self._ensure_removed_sorted()
|
||||
return self._removed[-1]
|
||||
return None
|
||||
|
||||
def pop_removed(self) -> Optional[int]:
|
||||
"""Pop lowest removed request index"""
|
||||
if self.has_removed():
|
||||
self._ensure_removed_sorted()
|
||||
return self._removed.pop()
|
||||
return None
|
||||
|
||||
def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
|
||||
"""Generate a logitsprocs batch update data structure
|
||||
and reset internal batch update builder state.
|
||||
|
||||
Args:
|
||||
batch_size: current persistent batch size
|
||||
|
||||
Returns:
|
||||
Frozen logitsprocs batch update instance; `None` if no updates
|
||||
"""
|
||||
# Reset removal-sorting logic
|
||||
self._is_removed_sorted = False
|
||||
if not any((self._removed, self.moved, self.added)):
|
||||
# No update; short-circuit
|
||||
return None
|
||||
# Build batch state update
|
||||
batch_update = BatchUpdate(
|
||||
batch_size=batch_size,
|
||||
removed=self._removed,
|
||||
moved=self.moved,
|
||||
added=self.added,
|
||||
)
|
||||
# Reset removed/moved/added update lists
|
||||
self._removed = []
|
||||
self.moved = []
|
||||
self.added = []
|
||||
return batch_update
|
||||
|
||||
|
||||
class LogitsProcessor(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""True if logits processor has no impact on the
|
||||
argmax computation in greedy sampling.
|
||||
NOTE: may or may not have the same value for all
|
||||
instances of a given LogitsProcessor subclass,
|
||||
depending on subclass implementation.
|
||||
TODO(andy): won't be utilized until logits
|
||||
processors are user-extensible
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_state(
|
||||
self,
|
||||
batch_update: Optional[BatchUpdate],
|
||||
) -> None:
|
||||
"""Called when there are new output tokens, prior
|
||||
to each forward pass.
|
||||
|
||||
Args:
|
||||
batch_update is non-None iff there have been
|
||||
changes to the batch makeup.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogitsProcessorManager:
|
||||
"""Encapsulates initialized logitsproc objects."""
|
||||
argmax_invariant: list[LogitsProcessor] = field(
|
||||
default_factory=list) # argmax-invariant logitsprocs
|
||||
non_argmax_invariant: list[LogitsProcessor] = field(
|
||||
default_factory=list) # non-argmax-invariant logitsprocs
|
||||
|
||||
@property
|
||||
def all(self) -> Iterator[LogitsProcessor]:
|
||||
"""Iterator over all logits processors."""
|
||||
return chain(self.argmax_invariant, self.non_argmax_invariant)
|
||||
|
||||
|
||||
###### ----- Built-in LogitsProcessor impls below here
|
||||
|
||||
|
||||
class MinPLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self, max_num_reqs: int, pin_memory: bool,
|
||||
device: DeviceLikeType):
|
||||
super().__init__()
|
||||
self.min_p_count: int = 0
|
||||
|
||||
self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
# Pre-allocated device tensor
|
||||
self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
# Current slice of the device tensor
|
||||
self.min_p: torch.Tensor = self.min_p_device[:0]
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""Min-p never impacts greedy sampling"""
|
||||
return True
|
||||
|
||||
def get_min_p_by_index(self, index: int) -> float:
|
||||
return float(self.min_p_cpu[index])
|
||||
|
||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||
if not batch_update:
|
||||
return
|
||||
|
||||
needs_update = False
|
||||
# Process added requests.
|
||||
for index, params, _ in batch_update.added:
|
||||
min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
|
||||
if self.min_p_cpu[index] != min_p:
|
||||
needs_update = True
|
||||
self.min_p_cpu[index] = min_p
|
||||
if min_p:
|
||||
self.min_p_count += 1
|
||||
|
||||
if self.min_p_count:
|
||||
# Process removed requests.
|
||||
needs_update |= bool(batch_update.removed)
|
||||
for index in batch_update.removed:
|
||||
if self.min_p_cpu[index]:
|
||||
self.min_p_count -= 1
|
||||
|
||||
# Process moved requests, unidirectional (a->b) and swap (a<->b)
|
||||
for adx, bdx, direct in batch_update.moved:
|
||||
change = (min_p_a :=
|
||||
self.min_p_cpu[adx]) != (min_p_b :=
|
||||
self.min_p_cpu[bdx])
|
||||
needs_update |= change
|
||||
if change:
|
||||
self.min_p_cpu[bdx] = min_p_a
|
||||
if direct == MoveDirectionality.SWAP:
|
||||
self.min_p_cpu[adx] = min_p_b
|
||||
|
||||
# Update tensors if needed.
|
||||
size = batch_update.batch_size
|
||||
if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
|
||||
self.min_p = self.min_p_device[:size]
|
||||
self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
|
||||
self.min_p.unsqueeze_(1)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if not self.min_p_count:
|
||||
return logits
|
||||
|
||||
# Convert logits to probability distribution
|
||||
probability_values = torch.nn.functional.softmax(logits, dim=-1)
|
||||
# Calculate maximum probabilities per sequence
|
||||
max_probabilities = torch.amax(probability_values,
|
||||
dim=-1,
|
||||
keepdim=True)
|
||||
# Adjust min_p
|
||||
adjusted_min_p = max_probabilities.mul_(self.min_p)
|
||||
# Identify valid tokens using threshold comparison
|
||||
invalid_token_mask = probability_values < adjusted_min_p
|
||||
# Apply mask using boolean indexing
|
||||
logits[invalid_token_mask] = -float('inf')
|
||||
return logits
|
||||
|
||||
|
||||
class LogitBiasLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self, pin_memory: bool, device: torch.device):
|
||||
super().__init__()
|
||||
self.biases: dict[int, dict[int, float]] = {}
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.bias_tensor: torch.Tensor = torch.tensor(())
|
||||
self.logits_slice = (self._device_tensor([], torch.int32),
|
||||
self._device_tensor([], torch.int32))
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""Logit bias can rebalance token probabilities and change the
|
||||
outcome of argmax in greedy sampling."""
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||
if not batch_update:
|
||||
return
|
||||
|
||||
# Process added requests.
|
||||
needs_update = bool(batch_update.added)
|
||||
for index, params, _ in batch_update.added:
|
||||
if isinstance(params, SamplingParams) and (lb :=
|
||||
params.logit_bias):
|
||||
self.biases[index] = lb
|
||||
else:
|
||||
self.biases.pop(index, None)
|
||||
|
||||
if self.biases:
|
||||
# Process removed requests.
|
||||
for index in batch_update.removed:
|
||||
if self.biases.pop(index, None):
|
||||
needs_update = True
|
||||
|
||||
# Process moved requests, unidirectional (a->b) and swap (a<->b)
|
||||
for a_index, b_index, direct in batch_update.moved:
|
||||
if direct == MoveDirectionality.UNIDIRECTIONAL:
|
||||
if (a_entry := self.biases.pop(a_index, None)) is None:
|
||||
if self.biases.pop(b_index, None) is not None:
|
||||
needs_update = True
|
||||
else:
|
||||
self.biases[b_index] = a_entry
|
||||
needs_update = True
|
||||
else:
|
||||
a_entry = self.biases.pop(a_index, None)
|
||||
if (b_entry := self.biases.pop(b_index, None)) is not None:
|
||||
self.biases[a_index] = b_entry
|
||||
needs_update = True
|
||||
if a_entry is not None:
|
||||
self.biases[b_index] = a_entry
|
||||
needs_update = True
|
||||
|
||||
# Update tensors if needed.
|
||||
if needs_update:
|
||||
reqs, tok_ids, biases = [], [], []
|
||||
for req, lb in self.biases.items():
|
||||
reqs.extend([req] * len(lb))
|
||||
tok_ids.extend(lb.keys())
|
||||
biases.extend(lb.values())
|
||||
|
||||
self.bias_tensor = self._device_tensor(biases, torch.float32)
|
||||
self.logits_slice = (self._device_tensor(reqs, torch.int32),
|
||||
self._device_tensor(tok_ids, torch.int32))
|
||||
|
||||
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (torch.tensor(data,
|
||||
device="cpu",
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory).to(device=self.device,
|
||||
non_blocking=True))
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.biases:
|
||||
logits[self.logits_slice] += self.bias_tensor
|
||||
return logits
|
||||
|
||||
|
||||
class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self, pin_memory: bool, device: torch.device):
|
||||
# index -> (min_toks, output_token_ids, stop_token_ids)
|
||||
super().__init__()
|
||||
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
# (req_idx_tensor,eos_tok_id_tensor)
|
||||
self.logits_slice: tuple[torch.Tensor,
|
||||
torch.Tensor] = (self._device_tensor(
|
||||
[], torch.int32),
|
||||
self._device_tensor(
|
||||
[], torch.int32))
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""By censoring stop tokens, min-tokens can change the outcome
|
||||
of the argmax operation in greedy sampling."""
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||
needs_update = False
|
||||
|
||||
if batch_update:
|
||||
# Process added requests.
|
||||
needs_update |= bool(batch_update.added)
|
||||
for index, params, output_tok_ids in batch_update.added:
|
||||
if (isinstance(params, SamplingParams)
|
||||
and (min_tokens := params.min_tokens)
|
||||
and len(output_tok_ids) < min_tokens):
|
||||
# Replace request metadata at batch index
|
||||
self.min_toks[index] = (min_tokens, output_tok_ids,
|
||||
params.all_stop_token_ids)
|
||||
else:
|
||||
# Drop request metadata at batch index
|
||||
self.min_toks.pop(index, None)
|
||||
|
||||
if self.min_toks:
|
||||
# Process removed requests.
|
||||
for index in batch_update.removed:
|
||||
if self.min_toks.pop(index, None):
|
||||
needs_update = True
|
||||
|
||||
# Process moved requests, unidirectional (a->b) and
|
||||
# swapped (a<->b)
|
||||
for a_index, b_index, direct in batch_update.moved:
|
||||
if direct == MoveDirectionality.UNIDIRECTIONAL:
|
||||
if (a_entry := self.min_toks.pop(a_index,
|
||||
None)) is None:
|
||||
if self.min_toks.pop(b_index, None) is not None:
|
||||
needs_update = True
|
||||
else:
|
||||
self.min_toks[b_index] = a_entry
|
||||
needs_update = True
|
||||
else:
|
||||
a_entry = self.min_toks.pop(a_index, None)
|
||||
if (b_entry := self.min_toks.pop(b_index,
|
||||
None)) is not None:
|
||||
self.min_toks[a_index] = b_entry
|
||||
needs_update = True
|
||||
if a_entry is not None:
|
||||
self.min_toks[b_index] = a_entry
|
||||
needs_update = True
|
||||
|
||||
if self.min_toks:
|
||||
# Check for any requests that have attained their min tokens.
|
||||
to_remove = tuple(index for index, (min_toks, out_tok_ids,
|
||||
_) in self.min_toks.items()
|
||||
if len(out_tok_ids) >= min_toks)
|
||||
if to_remove:
|
||||
needs_update = True
|
||||
for index in to_remove:
|
||||
del self.min_toks[index]
|
||||
|
||||
# Update tensors if needed.
|
||||
if needs_update:
|
||||
reqs: list[int] = []
|
||||
tok_ids: list[int] = []
|
||||
for req, (_, _, stop_tok_ids) in self.min_toks.items():
|
||||
reqs.extend([req] * len(stop_tok_ids))
|
||||
tok_ids.extend(stop_tok_ids)
|
||||
|
||||
self.logits_slice = (self._device_tensor(reqs, torch.int32),
|
||||
self._device_tensor(tok_ids, torch.int32))
|
||||
|
||||
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (torch.tensor(data,
|
||||
device="cpu",
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory).to(device=self.device,
|
||||
non_blocking=True))
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.min_toks:
|
||||
# Inhibit EOS token for requests which have not reached min length
|
||||
logits[self.logits_slice] = -float("inf")
|
||||
return logits
|
||||
|
||||
|
||||
def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
|
||||
device: torch.device) -> LogitsProcessorManager:
|
||||
"""Construct 'builtin' vLLM logitsprocs which the engine
|
||||
loads by default.
|
||||
|
||||
Args:
|
||||
pin_memory_available: pinned memory is available for use
|
||||
for use by logitsproc
|
||||
max_num_reqs: ceiling on request count in persistent batch
|
||||
device: inference device
|
||||
|
||||
Returns:
|
||||
Data structure encapsulating loaded logitsprocs
|
||||
"""
|
||||
min_tokens_logitproc = MinTokensLogitsProcessor(
|
||||
pin_memory=pin_memory_available, device=device)
|
||||
logit_bias_logitproc = LogitBiasLogitsProcessor(
|
||||
pin_memory=pin_memory_available, device=device)
|
||||
min_p_logitproc = MinPLogitsProcessor(
|
||||
pin_memory=pin_memory_available,
|
||||
device=device,
|
||||
# +1 for temporary swap space
|
||||
max_num_reqs=max_num_reqs + 1)
|
||||
return LogitsProcessorManager(
|
||||
non_argmax_invariant=[
|
||||
min_tokens_logitproc,
|
||||
logit_bias_logitproc,
|
||||
],
|
||||
argmax_invariant=[min_p_logitproc],
|
||||
)
|
||||
43
vllm/v1/sample/metadata.py
Normal file
43
vllm/v1/sample/metadata.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessorManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
|
||||
temperature: Optional[torch.Tensor]
|
||||
all_greedy: bool
|
||||
all_random: bool
|
||||
|
||||
top_p: Optional[torch.Tensor]
|
||||
top_k: Optional[torch.Tensor]
|
||||
|
||||
generators: dict[int, torch.Generator]
|
||||
|
||||
# None means no logprobs, 0 means sampled token logprobs only
|
||||
max_num_logprobs: Optional[int]
|
||||
|
||||
no_penalties: bool
|
||||
prompt_token_ids: Optional[torch.Tensor]
|
||||
frequency_penalties: torch.Tensor
|
||||
presence_penalties: torch.Tensor
|
||||
repetition_penalties: torch.Tensor
|
||||
|
||||
output_token_ids: list[list[int]]
|
||||
|
||||
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
|
||||
# vocab size).
|
||||
allowed_token_ids_mask: Optional[torch.Tensor]
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
bad_words_token_ids: dict[int, list[list[int]]]
|
||||
|
||||
# Loaded logits processors
|
||||
logitsprocs: LogitsProcessorManager
|
||||
0
vllm/v1/sample/ops/__init__.py
Normal file
0
vllm/v1/sample/ops/__init__.py
Normal file
39
vllm/v1/sample/ops/bad_words.py
Normal file
39
vllm/v1/sample/ops/bad_words.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
_SMALLEST_LOGIT = float("-inf")
|
||||
|
||||
|
||||
def _apply_bad_words_single_batch(
|
||||
logits: torch.Tensor,
|
||||
bad_words_token_ids: list[list[int]],
|
||||
past_tokens_ids: list[int],
|
||||
) -> None:
|
||||
for bad_word_ids in bad_words_token_ids:
|
||||
if len(bad_word_ids) > len(past_tokens_ids) + 1:
|
||||
continue
|
||||
|
||||
prefix_length = len(bad_word_ids) - 1
|
||||
last_token_id = bad_word_ids[-1]
|
||||
if prefix_length > 0:
|
||||
actual_prefix = past_tokens_ids[-prefix_length:]
|
||||
else:
|
||||
actual_prefix = []
|
||||
expected_prefix = bad_word_ids[:prefix_length]
|
||||
|
||||
assert len(actual_prefix) == len(expected_prefix)
|
||||
|
||||
if actual_prefix == expected_prefix:
|
||||
logits[last_token_id] = _SMALLEST_LOGIT
|
||||
|
||||
|
||||
def apply_bad_words(
|
||||
logits: torch.Tensor,
|
||||
bad_words_token_ids: dict[int, list[list[int]]],
|
||||
past_tokens_ids: list[list[int]],
|
||||
) -> None:
|
||||
for i, bad_words_ids in bad_words_token_ids.items():
|
||||
_apply_bad_words_single_batch(logits[i], bad_words_ids,
|
||||
past_tokens_ids[i])
|
||||
43
vllm/v1/sample/ops/penalties.py
Normal file
43
vllm/v1/sample/ops/penalties.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.utils import apply_penalties
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
|
||||
|
||||
def apply_all_penalties(
|
||||
logits: torch.Tensor,
|
||||
prompt_token_ids: torch.Tensor,
|
||||
presence_penalties: torch.Tensor,
|
||||
frequency_penalties: torch.Tensor,
|
||||
repetition_penalties: torch.Tensor,
|
||||
output_token_ids: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Applies presence, frequency and repetition penalties to the logits.
|
||||
"""
|
||||
_, vocab_size = logits.shape
|
||||
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
|
||||
logits.device)
|
||||
return apply_penalties(logits, prompt_token_ids, output_tokens_t,
|
||||
presence_penalties, frequency_penalties,
|
||||
repetition_penalties)
|
||||
|
||||
|
||||
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Convert the different list data structures to tensors.
|
||||
"""
|
||||
output_tokens_tensor = make_tensor_with_pad(
|
||||
output_token_ids,
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
pad=vocab_size,
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
pin_memory=is_pin_memory_available(),
|
||||
)
|
||||
return output_tokens_tensor.to(device, non_blocking=True)
|
||||
296
vllm/v1/sample/ops/topk_topp_sampler.py
Normal file
296
vllm/v1/sample/ops/topk_topp_sampler.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
import flashinfer.sampling
|
||||
is_flashinfer_available = True
|
||||
except ImportError:
|
||||
is_flashinfer_available = False
|
||||
|
||||
|
||||
class TopKTopPSampler(nn.Module):
|
||||
"""
|
||||
Module that performs optional top-k and top-p filtering followed by
|
||||
weighted random sampling of logits.
|
||||
|
||||
Implementations may update the logits tensor in-place.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda():
|
||||
if is_flashinfer_available:
|
||||
flashinfer_version = flashinfer.__version__
|
||||
if flashinfer_version < "0.2.3":
|
||||
logger.warning(
|
||||
"FlashInfer version >= 0.2.3 required. "
|
||||
"Falling back to default sampling implementation.")
|
||||
self.forward = self.forward_native
|
||||
elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
|
||||
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
|
||||
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
|
||||
# default it is unused). For backward compatibility, we set
|
||||
# `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
|
||||
# interpret it differently in V0 and V1 samplers: In V0,
|
||||
# None means False, while in V1, None means True. This is
|
||||
# why we use the condition
|
||||
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
|
||||
logger.info("Using FlashInfer for top-p & top-k sampling.")
|
||||
self.forward = self.forward_cuda
|
||||
else:
|
||||
logger.warning(
|
||||
"FlashInfer is available, but it is not enabled. "
|
||||
"Falling back to the PyTorch-native implementation of "
|
||||
"top-p & top-k sampling. For the best performance, "
|
||||
"please set VLLM_USE_FLASHINFER_SAMPLER=1.")
|
||||
self.forward = self.forward_native
|
||||
else:
|
||||
logger.warning(
|
||||
"FlashInfer is not available. Falling back to the PyTorch-"
|
||||
"native implementation of top-p & top-k sampling. For the "
|
||||
"best performance, please install FlashInfer.")
|
||||
self.forward = self.forward_native
|
||||
elif current_platform.is_tpu():
|
||||
self.forward = self.forward_tpu
|
||||
else:
|
||||
self.forward = self.forward_native
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
PyTorch-native implementation of top-k and top-p sampling.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
logits = apply_top_k_top_p(logits, k, p)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""More optimized implementation for top-k and top-p sampling."""
|
||||
if k is None and p is None:
|
||||
# We prefer `random_sample` over `flashinfer_sample` when sorting is
|
||||
# not needed. This is because `random_sample` does not require
|
||||
# CPU-GPU synchronization while `flashinfer_sample` does.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
if generators:
|
||||
logger.warning("FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
"PyTorch-native implementation.")
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
# flashinfer sampling functions expect contiguous logits.
|
||||
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
|
||||
# because of slicing operation in logits_processor.
|
||||
return flashinfer_sample(logits.contiguous(), k, p, generators)
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
logits = apply_top_k_top_p_tpu(logits, k, p)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
|
||||
|
||||
def apply_top_k_top_p_tpu(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k and top-p optimized for TPU.
|
||||
|
||||
This algorithm avoids using torch.scatter which is extremely slow on TPU.
|
||||
This is achieved by finding a "cut-off" element in the original logit, and
|
||||
after thresholding the logit using this cut-off, the remaining elements
|
||||
shall constitute the top-p set.
|
||||
|
||||
Note: in the case of tie (i.e. multipple cut-off elements present in the
|
||||
logit), all tie elements are included in the top-p set. In other words,
|
||||
this function does not break ties. Instead, these tie tokens have equal
|
||||
chance of being chosen during final sampling, so we can consider the tie
|
||||
being broken then.
|
||||
"""
|
||||
probs = logits.softmax(dim=-1)
|
||||
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
||||
top_k_count = top_k_count.unsqueeze(dim=1)
|
||||
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
||||
|
||||
# Make sure the no top-k rows are no-op.
|
||||
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
||||
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
||||
|
||||
elements_to_discard = probs < top_k_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
cumprob = torch.cumsum(probs_sort, dim=-1)
|
||||
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
||||
top_p_mask[:, -1] = False # at least one
|
||||
|
||||
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
||||
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
||||
elements_to_discard = probs < top_p_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Apply top-k and top-p masks to the logits.
|
||||
|
||||
If a top-p is used, this function will sort the logits tensor,
|
||||
which can be slow for large batches.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
if p is None:
|
||||
if k is None:
|
||||
return logits
|
||||
|
||||
# Avoid sorting vocab for top-k only case.
|
||||
return apply_top_k_only(logits, k)
|
||||
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
# Apply top-k.
|
||||
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
|
||||
# Get all the top_k values.
|
||||
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
|
||||
top_k_mask = logits_sort < top_k_mask
|
||||
logits_sort.masked_fill_(top_k_mask, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
# Apply top-p.
|
||||
probs_sort = logits_sort.softmax(dim=-1)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
|
||||
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
|
||||
# at least one
|
||||
top_p_mask[:, -1] = False
|
||||
logits_sort.masked_fill_(top_p_mask, -float("inf"))
|
||||
|
||||
# Re-sort the probabilities.
|
||||
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
|
||||
return logits
|
||||
|
||||
|
||||
def apply_top_k_only(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k mask to the logits.
|
||||
|
||||
This implementation doesn't involve sorting the entire vocab.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
no_top_k_mask = k == logits.shape[1]
|
||||
# Set non-top-k rows to 1 so that we can gather.
|
||||
k = k.masked_fill(no_top_k_mask, 1)
|
||||
max_top_k = k.max()
|
||||
# topk.values tensor has shape [batch_size, max_top_k].
|
||||
# Convert top k to 0-based index in range [0, max_top_k).
|
||||
k_index = k.sub_(1).unsqueeze(1)
|
||||
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
|
||||
# Handle non-topk rows.
|
||||
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
|
||||
logits.masked_fill_(logits < top_k_mask, -float("inf"))
|
||||
return logits
|
||||
|
||||
|
||||
def random_sample(
|
||||
probs: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
"""Randomly sample from the probabilities.
|
||||
|
||||
We use this function instead of torch.multinomial because torch.multinomial
|
||||
causes CPU-GPU synchronization.
|
||||
"""
|
||||
q = torch.empty_like(probs)
|
||||
# NOTE(woosuk): To batch-process the requests without their own seeds,
|
||||
# which is the common case, we first assume that every request does
|
||||
# not have its own seed. Then, we overwrite the values for the requests
|
||||
# that have their own seeds.
|
||||
if len(generators) != probs.shape[0]:
|
||||
q.exponential_()
|
||||
if generators:
|
||||
# TODO(woosuk): This can be slow because we handle each request
|
||||
# one by one. Optimize this.
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
return probs.div_(q).argmax(dim=-1).view(-1)
|
||||
|
||||
|
||||
def flashinfer_sample(
|
||||
logits: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
"""Sample from the logits using FlashInfer.
|
||||
|
||||
Statistically, this function is equivalent to the `random_sample` function.
|
||||
However, this function is faster because it avoids sorting the logits tensor
|
||||
via rejection sampling.
|
||||
|
||||
NOTE: The outputs of this function do not necessarily match the outputs of
|
||||
the `random_sample` function. It only guarantees that the outputs are
|
||||
statistically equivalent.
|
||||
|
||||
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
|
||||
does not. Call this function at the end of the forward pass to minimize
|
||||
the synchronization overhead.
|
||||
"""
|
||||
assert not (k is None and p is None)
|
||||
if k is None:
|
||||
# Top-p only.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
|
||||
probs, p, deterministic=True)
|
||||
elif p is None:
|
||||
# Top-k only.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
|
||||
probs, k, deterministic=True)
|
||||
else:
|
||||
# Both top-k and top-p.
|
||||
next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
|
||||
logits, k, p, deterministic=True)
|
||||
|
||||
return next_token_ids.view(-1)
|
||||
631
vllm/v1/sample/rejection_sampler.py
Normal file
631
vllm/v1/sample/rejection_sampler.py
Normal file
@@ -0,0 +1,631 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
|
||||
GREEDY_TEMPERATURE: tl.constexpr = -1
|
||||
# Maximum number of speculative draft tokens allowed per request in a single
|
||||
# step. This value is chosen to be large enough to handle typical use cases.
|
||||
MAX_SPEC_LEN = 32
|
||||
|
||||
|
||||
class RejectionSampler(nn.Module):
|
||||
"""
|
||||
The implementation strictly follows the algorithm described in
|
||||
https://arxiv.org/abs/2211.17192.
|
||||
However, we want to clarify the terminology used in the implementation:
|
||||
accepted tokens: tokens that are accepted based on the relationship
|
||||
between the "raw" draft and target probabilities.
|
||||
recovered tokens: tokens that are sampled based on the adjusted probability
|
||||
distribution, which is derived from both the draft and target
|
||||
probabilities.
|
||||
bonus tokens:
|
||||
If all proposed tokens are accepted, the bonus token is added to the
|
||||
end of the sequence. The bonus token is only sampled from the target
|
||||
probabilities. We pass in the bonus tokens instead of sampling them
|
||||
in the rejection sampler to allow for more flexibility in the
|
||||
sampling process. For example, we can use top_p, top_k sampling for
|
||||
bonus tokens, while spec decode does not support these sampling
|
||||
strategies.
|
||||
output tokens:
|
||||
Tokens are finally generated with the rejection sampler.
|
||||
output tokens = accepted tokens + recovered tokens + bonus tokens
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
metadata: SpecDecodeMetadata,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_logits: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Args:
|
||||
metadata:
|
||||
Metadata for spec decoding.
|
||||
draft_probs (Optional[torch.Tensor]):
|
||||
Probability distribution for the draft tokens. Shape is
|
||||
[num_tokens, vocab_size]. Can be None if probabilities are
|
||||
not provided, which is the case for ngram spec decode.
|
||||
target_logits (torch.Tensor):
|
||||
Target model's logits probability distribution.
|
||||
Shape is [num_tokens, vocab_size]. Here, probabilities from
|
||||
different requests are flattened into a single tensor because
|
||||
this is the shape of the output logits.
|
||||
NOTE: `target_logits` can be updated in place to save memory.
|
||||
bonus_token_ids_tensor (torch.Tensor):
|
||||
A tensor containing bonus tokens. Shape is [batch_size, 1].
|
||||
Bonus tokens are added to the end of the sequence if all
|
||||
proposed tokens are accepted. We generate the bonus tokens
|
||||
outside of the rejection sampler with the default sampling
|
||||
strategy. It allows for more flexibility in the sampling
|
||||
process such as top_p, top_k sampling.
|
||||
sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
|
||||
Additional metadata needed for sampling, such as temperature,
|
||||
top-k/top-p parameters, or other relevant information.
|
||||
Returns:
|
||||
output_token_ids (torch.Tensor):
|
||||
A tensor containing the final output token IDs.
|
||||
'''
|
||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||
# [num_tokens, vocab_size]
|
||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||
# `compute_probs` function.
|
||||
target_probs = compute_probs(
|
||||
target_logits,
|
||||
metadata.cu_num_draft_tokens,
|
||||
sampling_metadata,
|
||||
)
|
||||
|
||||
output_token_ids = rejection_sample(
|
||||
metadata.draft_token_ids,
|
||||
metadata.num_draft_tokens,
|
||||
metadata.max_spec_len,
|
||||
metadata.cu_num_draft_tokens,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
sampling_metadata,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
@staticmethod
|
||||
def parse_output(
|
||||
output_token_ids: torch.Tensor,
|
||||
vocab_size: int,
|
||||
) -> list[list[int]]:
|
||||
"""Parse the output of the rejection sampler.
|
||||
|
||||
Args:
|
||||
output_token_ids: The sampled token IDs in shape
|
||||
[batch_size, max_spec_len + 1]. The rejected tokens are
|
||||
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
|
||||
and will be filtered out in this function.
|
||||
vocab_size: The size of the vocabulary.
|
||||
|
||||
Returns:
|
||||
A list of lists of token IDs.
|
||||
"""
|
||||
output_token_ids_np = output_token_ids.cpu().numpy()
|
||||
# Create mask for valid tokens.
|
||||
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
|
||||
(output_token_ids_np < vocab_size))
|
||||
outputs = [
|
||||
row[valid_mask[i]].tolist()
|
||||
for i, row in enumerate(output_token_ids_np)
|
||||
]
|
||||
return outputs
|
||||
|
||||
|
||||
def rejection_sample(
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [batch_size]
|
||||
num_draft_tokens: list[int],
|
||||
max_spec_len: int,
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert draft_token_ids.ndim == 1
|
||||
assert draft_probs is None or draft_probs.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
assert target_probs.ndim == 2
|
||||
|
||||
batch_size = len(num_draft_tokens)
|
||||
num_tokens = draft_token_ids.shape[0]
|
||||
vocab_size = target_probs.shape[-1]
|
||||
device = target_probs.device
|
||||
assert draft_token_ids.is_contiguous()
|
||||
assert draft_probs is None or draft_probs.is_contiguous()
|
||||
assert target_probs.is_contiguous()
|
||||
assert bonus_token_ids.is_contiguous()
|
||||
assert target_probs.shape == (num_tokens, vocab_size)
|
||||
|
||||
# Create output buffer.
|
||||
output_token_ids = torch.empty(
|
||||
(batch_size, max_spec_len + 1),
|
||||
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
|
||||
device=device,
|
||||
)
|
||||
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
|
||||
|
||||
if sampling_metadata.all_greedy:
|
||||
is_greedy = None
|
||||
else:
|
||||
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
|
||||
if not sampling_metadata.all_random:
|
||||
# Rejection sampling for greedy sampling requests.
|
||||
target_argmax = target_probs.argmax(dim=-1)
|
||||
rejection_greedy_sample_kernel[(batch_size, )](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
num_warps=1,
|
||||
)
|
||||
if sampling_metadata.all_greedy:
|
||||
return output_token_ids
|
||||
|
||||
# Generate uniform probabilities for rejection sampling.
|
||||
# [num_tokens]
|
||||
uniform_probs = generate_uniform_probs(
|
||||
num_tokens,
|
||||
num_draft_tokens,
|
||||
sampling_metadata.generators,
|
||||
device,
|
||||
)
|
||||
|
||||
# Sample recovered tokens for each position.
|
||||
# [num_tokens]
|
||||
recovered_token_ids = sample_recovered_tokens(
|
||||
max_spec_len,
|
||||
num_draft_tokens,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device,
|
||||
)
|
||||
|
||||
# Rejection sampling for random sampling requests.
|
||||
rejection_random_sample_kernel[(batch_size, )](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
recovered_token_ids,
|
||||
uniform_probs,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
num_warps=1,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
|
||||
def compute_probs(
|
||||
logits: torch.Tensor, # [num_tokens, vocab_size]
|
||||
cu_num_draft_tokens: torch.Tensor, # [batch_size]
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Compute probability distribution from logits based on sampling metadata.
|
||||
|
||||
This function applies temperature scaling to the logits and converts
|
||||
them to probabilities using softmax. For greedy decoding, it returns
|
||||
the original logits.
|
||||
|
||||
Args:
|
||||
logits: Input logits tensor to be converted to probabilities.
|
||||
cu_num_draft_tokens: Cumulative number of draft tokens.
|
||||
sampling_metadata: Metadata containing sampling parameters such as
|
||||
temperature and whether greedy sampling is used.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Probability distribution (softmax of scaled logits)
|
||||
if non-greedy sampling is used, otherwise returns the
|
||||
original logits.
|
||||
"""
|
||||
assert logits.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
if sampling_metadata.all_greedy:
|
||||
return logits
|
||||
|
||||
num_tokens = logits.shape[0]
|
||||
temperature = expand_batch_to_tokens(
|
||||
sampling_metadata.temperature,
|
||||
cu_num_draft_tokens,
|
||||
num_tokens,
|
||||
replace_from=GREEDY_TEMPERATURE,
|
||||
replace_to=1,
|
||||
)
|
||||
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
|
||||
logits.div_(temperature.unsqueeze(-1))
|
||||
|
||||
# Get expanded top_k and top_p tensors.
|
||||
top_k = None
|
||||
if sampling_metadata.top_k is not None:
|
||||
top_k = expand_batch_to_tokens(
|
||||
sampling_metadata.top_k,
|
||||
cu_num_draft_tokens,
|
||||
num_tokens,
|
||||
)
|
||||
top_p = None
|
||||
if sampling_metadata.top_p is not None:
|
||||
top_p = expand_batch_to_tokens(
|
||||
sampling_metadata.top_p,
|
||||
cu_num_draft_tokens,
|
||||
num_tokens,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
|
||||
# which is slow for large vocab sizes. This may cause performance issues.
|
||||
logits = apply_top_k_top_p(logits, top_k, top_p)
|
||||
output_prob = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return output_prob
|
||||
|
||||
|
||||
def expand_batch_to_tokens(
|
||||
x: torch.Tensor, # [batch_size]
|
||||
cu_num_tokens: torch.Tensor, # [batch_size]
|
||||
num_tokens: int,
|
||||
replace_from: int = 0,
|
||||
replace_to: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Expand [batch_size] tensor to [num_tokens] tensor based on the number of
|
||||
tokens per batch in cu_num_tokens.
|
||||
|
||||
For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
|
||||
num_tokens = 6, and expanded_x = [a, a, b, b, b, c].
|
||||
|
||||
Args:
|
||||
x: [batch_size] tensor to expand.
|
||||
cu_num_tokens: [batch_size] tensor containing the cumulative number of
|
||||
tokens per batch. Each element represents the total number of
|
||||
tokens up to and including that batch.
|
||||
num_tokens: Total number of tokens.
|
||||
replace_from: int = 0
|
||||
Value to be replaced if it is found in x.
|
||||
replace_to: int = 0
|
||||
Value to replace with when replace_from is found.
|
||||
Returns:
|
||||
expanded_x: [num_tokens] tensor.
|
||||
"""
|
||||
batch_size = x.shape[0]
|
||||
assert cu_num_tokens.shape[0] == batch_size
|
||||
expanded_x = x.new_empty(num_tokens)
|
||||
expand_kernel[(batch_size, )](
|
||||
expanded_x,
|
||||
x,
|
||||
cu_num_tokens,
|
||||
replace_from,
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
|
||||
num_warps=1,
|
||||
)
|
||||
return expanded_x
|
||||
|
||||
|
||||
def generate_uniform_probs(
|
||||
num_tokens: int,
|
||||
num_draft_tokens: list[int],
|
||||
generators: dict[int, torch.Generator],
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generates a batch of uniform random samples, with optional seeding
|
||||
if available.
|
||||
|
||||
This method creates a tensor of shape `(num_tokens, )` filled
|
||||
with uniform random values in the range [0, 1). If `generators` is provided,
|
||||
the requests with their own seeds will use the provided `torch.Generator`
|
||||
for reproducibility. The samples for the other requests will be generated
|
||||
without a seed.
|
||||
|
||||
Args:
|
||||
num_tokens : int
|
||||
Total number of tokens.
|
||||
num_draft_tokens : List[List[int]]
|
||||
Number of draft tokens per request.
|
||||
generators : Optional[Dict[int, torch.Generator]]
|
||||
A dictionary mapping indices in the batch to
|
||||
`torch.Generator` objects.
|
||||
device : torch.device
|
||||
The device on which to allocate the tensor.
|
||||
Returns:
|
||||
uniform_rand : torch.Tensor
|
||||
A tensor of shape `(num_tokens, )` containing uniform
|
||||
random values in the range [0, 1).
|
||||
"""
|
||||
uniform_probs = torch.rand(
|
||||
(num_tokens, ),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
start_idx = 0
|
||||
for req_idx, n in enumerate(num_draft_tokens):
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if n == 0:
|
||||
continue
|
||||
end_idx = start_idx + n
|
||||
generator = generators.get(req_idx)
|
||||
if generator is not None:
|
||||
uniform_probs[start_idx:end_idx].uniform_(generator=generator)
|
||||
start_idx = end_idx
|
||||
return uniform_probs
|
||||
|
||||
|
||||
def sample_recovered_tokens(
|
||||
max_spec_len: int,
|
||||
num_draft_tokens: list[int],
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): Create only one distribution for each request.
|
||||
batch_size = len(num_draft_tokens)
|
||||
vocab_size = target_probs.shape[-1]
|
||||
q = torch.empty(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
|
||||
for i, generator in sampling_metadata.generators.items():
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if num_draft_tokens[i] > 0:
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
recovered_token_ids = torch.empty_like(draft_token_ids)
|
||||
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
|
||||
recovered_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
q,
|
||||
vocab_size,
|
||||
triton.next_power_of_2(vocab_size),
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
)
|
||||
return recovered_token_ids
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
def rejection_greedy_sample_kernel(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
target_argmax_ptr, # [num_tokens]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
is_greedy_ptr, # [batch_size] or None
|
||||
max_spec_len,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
|
||||
# re-compilation may happen during runtime when is_greedy_ptr is None.
|
||||
if is_greedy_ptr is None:
|
||||
is_greedy = True
|
||||
else:
|
||||
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
||||
if not is_greedy:
|
||||
# Early exit for non-greedy sampling requests.
|
||||
return
|
||||
|
||||
if req_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
rejected = False
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
target_argmax_id)
|
||||
if draft_token_id != target_argmax_id:
|
||||
# Reject.
|
||||
rejected = True
|
||||
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
num_draft_tokens, bonus_token_id)
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
def rejection_random_sample_kernel(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
recovered_token_ids_ptr, # [num_tokens]
|
||||
uniform_probs_ptr, # [num_tokens]
|
||||
is_greedy_ptr, # [batch_size]
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
||||
if is_greedy:
|
||||
# Early exit for greedy sampling requests.
|
||||
return
|
||||
|
||||
if req_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
rejected = False
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_prob = 1
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
target_prob = tl.load(target_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||
# NOTE(woosuk): While the draft probability should never be 0,
|
||||
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
||||
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||
# Accept.
|
||||
token_id = draft_token_id
|
||||
else:
|
||||
# Reject. Use recovered token.
|
||||
rejected = True
|
||||
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
token_id)
|
||||
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
num_draft_tokens, bonus_token_id)
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
|
||||
def expand_kernel(
|
||||
output_ptr, # [num_tokens]
|
||||
input_ptr, # [batch_size]
|
||||
cu_num_tokens_ptr, # [batch_size]
|
||||
replace_from,
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
if req_idx == 0: # noqa: SIM108
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
src_val = tl.load(input_ptr + req_idx)
|
||||
src_val = tl.where(src_val == replace_from, replace_to, src_val)
|
||||
offset = tl.arange(0, MAX_NUM_TOKENS)
|
||||
tl.store(output_ptr + start_idx + offset,
|
||||
src_val,
|
||||
mask=offset < num_tokens)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def sample_recovered_tokens_kernel(
|
||||
output_token_ids_ptr, # [num_tokens]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
q_ptr, # [batch_size, vocab_size]
|
||||
vocab_size,
|
||||
PADDED_VOCAB_SIZE: tl.constexpr,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
if req_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
# Early exit for out-of-range positions.
|
||||
pos = tl.program_id(1)
|
||||
if pos >= num_draft_tokens:
|
||||
return
|
||||
|
||||
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
# Temporarily zero out the probability of the draft token.
|
||||
# This is essentially the same as target_prob - draft_prob, except that
|
||||
# n-gram does not have draft_prob. We regard it as 1.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
0)
|
||||
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
target_prob = tl.load(target_probs_ptr +
|
||||
(start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
prob = tl.maximum(target_prob - draft_prob, 0)
|
||||
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
|
||||
# `tl.argmax` will select the maximum value.
|
||||
|
||||
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=float("-inf"))
|
||||
recovered_id = tl.argmax(prob / q, axis=-1)
|
||||
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
|
||||
|
||||
if NO_DRAFT_PROBS:
|
||||
# Restore the original probability.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
orig_prob)
|
||||
226
vllm/v1/sample/sampler.py
Normal file
226
vllm/v1/sample/sampler.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A layer that samples the next tokens from the model's outputs."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
||||
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler()
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
# NOTE(woosuk): Use the original logits (before any penalties or
|
||||
# temperature scaling) for the top-k logprobs.
|
||||
# This is different from the V0 sampler, which uses the logits that
|
||||
# is used for sampling (after penalties and temperature scaling).
|
||||
# TODO(rob): provide option for logprobs post sampling.
|
||||
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
raw_logprobs = self.compute_logprobs(logits)
|
||||
|
||||
# Use float32 for the logits.
|
||||
logits = logits.to(torch.float32)
|
||||
# Apply allowed token ids.
|
||||
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
|
||||
# Apply bad words exclusion.
|
||||
logits = self.apply_bad_words(logits, sampling_metadata)
|
||||
|
||||
# Apply logits processors which can impact greedy sampling
|
||||
for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply penalties (e.g., min_tokens, freq_penalties).
|
||||
logits = self.apply_penalties(logits, sampling_metadata)
|
||||
# Sample the next token.
|
||||
sampled = self.sample(logits, sampling_metadata)
|
||||
# Convert sampled token ids to int64 (long) type to ensure compatibility
|
||||
# with subsequent operations that may use these values as indices.
|
||||
# This conversion is necessary because FlashInfer sampling operations
|
||||
# return int32 (while PyTorch argmax and topk return int64).
|
||||
sampled = sampled.long()
|
||||
|
||||
# Gather the logprobs of the topk and sampled token (if requested).
|
||||
# Get logprobs and rank tensors (if requested)
|
||||
logprobs_tensors = None if num_logprobs is None else \
|
||||
self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
sampled = sampled.to(torch.int32)
|
||||
|
||||
# These are GPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.unsqueeze(-1),
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def apply_temperature(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
temp: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Use in-place division to avoid creating a new tensor.
|
||||
return logits.div_(temp.unsqueeze(dim=1))
|
||||
|
||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.argmax(dim=-1).view(-1)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Sample logits based on sampling metadata.
|
||||
|
||||
The various logits processing functions called in this method
|
||||
may update the logits tensor in-place.
|
||||
"""
|
||||
|
||||
assert not (sampling_metadata.all_greedy
|
||||
and sampling_metadata.all_random)
|
||||
if sampling_metadata.all_random:
|
||||
greedy_sampled = None
|
||||
else:
|
||||
greedy_sampled = self.greedy_sample(logits)
|
||||
if sampling_metadata.all_greedy:
|
||||
return greedy_sampled
|
||||
|
||||
assert sampling_metadata.temperature is not None
|
||||
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
# Apply logits processors that only apply to random sampling
|
||||
# (argmax invariant)
|
||||
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled = self.topk_topp_sampler(
|
||||
logits,
|
||||
sampling_metadata.generators,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
if greedy_sampled is None:
|
||||
return random_sampled
|
||||
|
||||
sampled = torch.where(
|
||||
sampling_metadata.temperature < _SAMPLING_EPS,
|
||||
greedy_sampled,
|
||||
random_sampled,
|
||||
out=greedy_sampled, # Reuse tensor
|
||||
)
|
||||
return sampled
|
||||
|
||||
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
def gather_logprobs(
|
||||
self,
|
||||
logprobs: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
"""
|
||||
Gather logprobs for topk and sampled/prompt token.
|
||||
|
||||
Args:
|
||||
logprobs: (num tokens) x (vocab) tensor
|
||||
num_logprobs: minimum number of logprobs to
|
||||
retain per token
|
||||
token_ids: prompt tokens (if prompt logprobs)
|
||||
or sampled tokens (if sampled
|
||||
logprobs); 1D token ID tensor
|
||||
with (num tokens) elements
|
||||
Must be int64.
|
||||
|
||||
Returns:
|
||||
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
||||
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
||||
Sampled token rank tensor, (num tokens)
|
||||
"""
|
||||
assert token_ids.dtype == torch.int64
|
||||
# Find the topK values.
|
||||
topk_logprobs, topk_indices = torch.topk(logprobs,
|
||||
num_logprobs,
|
||||
dim=-1)
|
||||
|
||||
# Get with the logprob of the prompt or sampled token.
|
||||
token_ids = token_ids.unsqueeze(-1)
|
||||
token_logprobs = logprobs.gather(-1, token_ids)
|
||||
|
||||
# Compute the ranks of the actual token.
|
||||
token_ranks = (logprobs >= token_logprobs).sum(-1)
|
||||
|
||||
# Concatenate together with the topk.
|
||||
indices = torch.cat((token_ids, topk_indices), dim=1)
|
||||
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
indices = indices.to(torch.int32)
|
||||
|
||||
return LogprobsTensors(indices, logprobs, token_ranks)
|
||||
|
||||
def apply_penalties(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
if not sampling_metadata.no_penalties:
|
||||
assert sampling_metadata.prompt_token_ids is not None
|
||||
logits = apply_all_penalties(
|
||||
logits,
|
||||
sampling_metadata.prompt_token_ids,
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.frequency_penalties,
|
||||
sampling_metadata.repetition_penalties,
|
||||
sampling_metadata.output_token_ids,
|
||||
)
|
||||
return logits
|
||||
|
||||
def apply_allowed_token_ids(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
if sampling_metadata.allowed_token_ids_mask is not None:
|
||||
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
|
||||
float("-inf"))
|
||||
return logits
|
||||
|
||||
def apply_bad_words(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
if sampling_metadata.bad_words_token_ids:
|
||||
apply_bad_words(
|
||||
logits,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
sampling_metadata.output_token_ids,
|
||||
)
|
||||
return logits
|
||||
0
vllm/v1/sample/tpu/__init__.py
Normal file
0
vllm/v1/sample/tpu/__init__.py
Normal file
124
vllm/v1/sample/tpu/metadata.py
Normal file
124
vllm/v1/sample/tpu/metadata.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.worker.tpu_input_batch import InputBatch
|
||||
|
||||
DEFAULT_SAMPLING_PARAMS = dict(
|
||||
temperature=-1.0,
|
||||
min_p=0.0,
|
||||
# strictly disabled for now
|
||||
top_k=0,
|
||||
top_p=1.0,
|
||||
# frequency_penalties=0.0,
|
||||
# presence_penalties=0.0,
|
||||
# repetition_penalties=0.0,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TPUSupportedSamplingMetadata:
|
||||
# This class exposes a more xla-friendly interface than SamplingMetadata
|
||||
# on TPU, in particular all arguments should be traceable and no optionals
|
||||
# are allowed, to avoid graph recompilation on Nones.
|
||||
temperature: torch.Tensor = None
|
||||
|
||||
min_p: torch.Tensor = None
|
||||
top_k: torch.Tensor = None
|
||||
top_p: torch.Tensor = None
|
||||
|
||||
all_greedy: bool = True
|
||||
|
||||
# Whether logprobs are to be gathered in this batch of request. To balance
|
||||
# out compile time and runtime, a fixed `max_number_logprobs` value is used
|
||||
# when gathering logprobs, regardless of the values specified in the batch.
|
||||
logprobs: bool = False
|
||||
|
||||
# TODO No penalties for now
|
||||
no_penalties: bool = True
|
||||
prompt_token_ids = None
|
||||
frequency_penalties = None
|
||||
presence_penalties = None
|
||||
repetition_penalties = None
|
||||
# should use tensor
|
||||
output_token_ids: list[list[int]] = field(default_factory=lambda: list())
|
||||
|
||||
min_tokens = None # impl is not vectorized
|
||||
|
||||
logit_bias: list[Optional[dict[int, float]]] = field(
|
||||
default_factory=lambda: list())
|
||||
|
||||
allowed_token_ids_mask = None
|
||||
bad_words_token_ids = None
|
||||
|
||||
# Generator not supported by xla
|
||||
_generators: dict[int,
|
||||
torch.Generator] = field(default_factory=lambda: dict())
|
||||
|
||||
@property
|
||||
def generators(self) -> dict[int, torch.Generator]:
|
||||
# Generator not supported by torch/xla. This field must be immutable.
|
||||
return self._generators
|
||||
|
||||
@classmethod
|
||||
def from_input_batch(
|
||||
cls,
|
||||
input_batch: InputBatch,
|
||||
padded_num_reqs: int,
|
||||
xla_device: torch.device,
|
||||
generate_params_if_all_greedy: bool = False
|
||||
) -> "TPUSupportedSamplingMetadata":
|
||||
"""
|
||||
Copy sampling tensors slices from `input_batch` to on device tensors.
|
||||
|
||||
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
|
||||
slices dynamic shapes on device tensors. This impl moves the dynamic
|
||||
ops to CPU and produces tensors of fixed `padded_num_reqs` size.
|
||||
|
||||
Args:
|
||||
input_batch: The input batch containing sampling parameters.
|
||||
padded_num_reqs: The padded number of requests.
|
||||
xla_device: The XLA device.
|
||||
generate_params_if_all_greedy: If True, generate sampling parameters
|
||||
even if all requests are greedy. this is useful for cases where
|
||||
we want to pre-compile a graph with sampling parameters, even if
|
||||
they are not strictly needed for greedy decoding.
|
||||
"""
|
||||
needs_logprobs = input_batch.max_num_logprobs>0 if \
|
||||
input_batch.max_num_logprobs else False
|
||||
# Early return to avoid unnecessary cpu to tpu copy
|
||||
if (input_batch.all_greedy is True
|
||||
and generate_params_if_all_greedy is False):
|
||||
return cls(all_greedy=True, logprobs=needs_logprobs)
|
||||
|
||||
num_reqs = input_batch.num_reqs
|
||||
|
||||
def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
|
||||
# Pad value is the default one.
|
||||
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
|
||||
|
||||
fill_slice(input_batch.temperature_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||
fill_slice(input_batch.min_p_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["min_p"])
|
||||
fill_slice(input_batch.top_k_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["top_k"])
|
||||
fill_slice(input_batch.top_p_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["top_p"])
|
||||
|
||||
# Slice persistent device tensors to a fixed pre-compiled padded shape.
|
||||
return cls(
|
||||
temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
|
||||
to(xla_device),
|
||||
all_greedy=input_batch.all_greedy,
|
||||
# TODO enable more and avoid returning None values
|
||||
top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device),
|
||||
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device),
|
||||
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device),
|
||||
logprobs=needs_logprobs)
|
||||
145
vllm/v1/sample/tpu/sampler.py
Normal file
145
vllm/v1/sample/tpu/sampler.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Sampler layer implementing TPU supported operations."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
# Use float32 for the logits.
|
||||
logits = logits.to(torch.float32)
|
||||
# Sample the next token.
|
||||
sampled = self.sample(logits, sampling_metadata)
|
||||
|
||||
# These are TPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.unsqueeze(-1),
|
||||
logprobs_tensors=None)
|
||||
return sampler_output
|
||||
|
||||
def apply_temperature(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
temp: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return logits.div_(temp.unsqueeze(dim=1))
|
||||
|
||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.argmax(dim=-1).view(-1)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
greedy_sampled = self.greedy_sample(logits)
|
||||
|
||||
assert sampling_metadata.temperature is not None
|
||||
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
# Apply min_p.
|
||||
if sampling_metadata.min_p is not None:
|
||||
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled = self.topk_topp_sampler(
|
||||
logits,
|
||||
sampling_metadata.generators,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
|
||||
greedy_sampled, random_sampled)
|
||||
return sampled
|
||||
|
||||
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
def gather_logprobs(
|
||||
self,
|
||||
logprobs: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
"""
|
||||
Gather logprobs for topk and sampled/prompt token.
|
||||
|
||||
Args:
|
||||
logits: (num tokens) x (vocab) tensor
|
||||
num_logprobs: minimum number of logprobs to
|
||||
retain per token
|
||||
token_ids: prompt tokens (if prompt logprobs)
|
||||
or sampled tokens (if sampled
|
||||
logprobs); 1D token ID tensor
|
||||
with (num tokens) elements
|
||||
|
||||
Returns:
|
||||
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
||||
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
||||
Sampled token rank tensor, (num tokens)
|
||||
"""
|
||||
# Find the topK values.
|
||||
topk_logprobs, topk_indices = torch.topk(logprobs,
|
||||
num_logprobs,
|
||||
dim=-1)
|
||||
|
||||
# Get with the logprob of the prompt or sampled token.
|
||||
token_ids = token_ids.unsqueeze(-1)
|
||||
token_logprobs = logprobs.gather(-1, token_ids)
|
||||
|
||||
# Compute the ranks of the actual token.
|
||||
token_ranks = (logprobs >= token_logprobs).sum(-1)
|
||||
|
||||
# Concatenate together with the topk.
|
||||
indices = torch.cat((token_ids, topk_indices), dim=1)
|
||||
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
indices = indices.to(torch.int32)
|
||||
|
||||
return LogprobsTensors(indices, logprobs, token_ranks)
|
||||
|
||||
def apply_min_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
min_p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Filters logits using adaptive probability thresholding.
|
||||
"""
|
||||
# Convert logits to probability distribution
|
||||
probability_values = torch.nn.functional.softmax(logits, dim=-1)
|
||||
# Calculate maximum probabilities per sequence
|
||||
max_probabilities = torch.amax(probability_values,
|
||||
dim=-1,
|
||||
keepdim=True)
|
||||
# Reshape min_p for broadcasting
|
||||
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
|
||||
# Identify valid tokens using threshold comparison
|
||||
valid_token_mask = probability_values >= adjusted_min_p
|
||||
# Apply mask using boolean indexing (xla friendly)
|
||||
logits.masked_fill_(~valid_token_mask, -float("inf"))
|
||||
return logits
|
||||
Reference in New Issue
Block a user