Sync from v0.13
0
vllm/v1/sample/__init__.py
Normal file
352
vllm/v1/sample/logits_processor/__init__.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
import inspect
|
||||
import itertools
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.torch_utils import guard_cuda_initialization
|
||||
from vllm.v1.sample.logits_processor.builtin import (
|
||||
LogitBiasLogitsProcessor,
|
||||
MinPLogitsProcessor,
|
||||
MinTokensLogitsProcessor,
|
||||
process_dict_updates,
|
||||
)
|
||||
from vllm.v1.sample.logits_processor.interface import (
|
||||
BatchUpdate,
|
||||
LogitsProcessor,
|
||||
MoveDirectionality,
|
||||
)
|
||||
from vllm.v1.sample.logits_processor.state import BatchUpdateBuilder, LogitsProcessors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Error message when the user tries to initialize vLLM with a pooling model
# and custom logits processors
STR_POOLING_REJECTS_LOGITSPROCS = (
    "Pooling models do not support custom logits processors."
)
|
||||
|
||||
# Error message when the user tries to initialize vLLM with speculative
# decoding enabled and custom logits processors
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
    "Custom logits processors are not supported when speculative decoding is enabled."
)
|
||||
|
||||
LOGITSPROCS_GROUP = "vllm.logits_processors"
|
||||
|
||||
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
|
||||
MinTokensLogitsProcessor,
|
||||
LogitBiasLogitsProcessor,
|
||||
MinPLogitsProcessor,
|
||||
]
|
||||
|
||||
|
||||
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
|
||||
"""Load all installed logit processor plugins"""
|
||||
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP)
|
||||
if len(installed_logitsprocs_plugins) == 0:
|
||||
logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP)
|
||||
return []
|
||||
|
||||
# Load logitsprocs plugins
|
||||
logger.debug("Loading installed logitsprocs plugins (group %s):", LOGITSPROCS_GROUP)
|
||||
classes: list[type[LogitsProcessor]] = []
|
||||
for entrypoint in installed_logitsprocs_plugins:
|
||||
try:
|
||||
logger.debug(
|
||||
"- Loading logitproc plugin entrypoint=%s target=%s",
|
||||
entrypoint.name,
|
||||
entrypoint.value,
|
||||
)
|
||||
with guard_cuda_initialization():
|
||||
classes.append(entrypoint.load())
|
||||
except Exception as e:
|
||||
logger.error("Failed to load LogitsProcessor plugin %s: %s", entrypoint, e)
|
||||
raise RuntimeError(
|
||||
f"Failed to load LogitsProcessor plugin {entrypoint}"
|
||||
) from e
|
||||
return classes
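# Illustrative sketch (not part of the upstream file): an external package can
# expose a logits processor through the "vllm.logits_processors" entry point
# group scanned above. The package and class names are hypothetical.
#
#   # pyproject.toml of the plugin package
#   [project.entry-points."vllm.logits_processors"]
#   my_logitproc = "my_plugin.processors:MyLogitsProcessor"
#
# Once installed, the plugin is discoverable with the same importlib.metadata
# API used above:
#
#   from importlib.metadata import entry_points
#   for ep in entry_points(group="vllm.logits_processors"):
#       cls = ep.load()  # expected to be a LogitsProcessor subclass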
|
||||
|
||||
|
||||
def _load_logitsprocs_by_fqcns(
|
||||
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
|
||||
) -> list[type[LogitsProcessor]]:
|
||||
"""Load logit processor types, identifying them by fully-qualified class
|
||||
names (FQCNs).
|
||||
|
||||
Effectively, a mixed list of logitproc types and FQCN strings is converted
|
||||
into a list of entirely logitproc types, by loading from the FQCNs.
|
||||
|
||||
FQCN syntax is <module>:<type> i.e. x.y.z:CustomLogitProc
|
||||
|
||||
Already-loaded logitproc types must be subclasses of LogitsProcessor
|
||||
|
||||
Args:
|
||||
logits_processors: Potentially mixed list of logitsprocs types and FQCN
|
||||
strings for logitproc types
|
||||
|
||||
Returns:
|
||||
List of logitproc types
|
||||
|
||||
"""
|
||||
if not logits_processors:
|
||||
return []
|
||||
|
||||
logger.debug(
|
||||
"%s additional custom logits processors specified, checking whether "
|
||||
"they need to be loaded.",
|
||||
len(logits_processors),
|
||||
)
|
||||
|
||||
classes: list[type[LogitsProcessor]] = []
|
||||
for ldx, logitproc in enumerate(logits_processors):
|
||||
if isinstance(logitproc, type):
|
||||
logger.debug(" - Already-loaded logit processor: %s", logitproc.__name__)
|
||||
if not issubclass(logitproc, LogitsProcessor):
|
||||
raise ValueError(
|
||||
f"{logitproc.__name__} is not a subclass of LogitsProcessor"
|
||||
)
|
||||
classes.append(logitproc)
|
||||
continue
|
||||
|
||||
logger.debug("- Loading logits processor %s", logitproc)
|
||||
module_path, qualname = logitproc.split(":")
|
||||
|
||||
try:
|
||||
# Load module
|
||||
with guard_cuda_initialization():
|
||||
module = importlib.import_module(module_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load %sth LogitsProcessor plugin %s: %s",
|
||||
ldx,
|
||||
logitproc,
|
||||
e,
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
|
||||
) from e
|
||||
|
||||
# Walk down dotted name to get logitproc class
|
||||
obj = module
|
||||
for attr in qualname.split("."):
|
||||
obj = getattr(obj, attr)
|
||||
if not isinstance(obj, type):
|
||||
raise ValueError("Loaded logit processor must be a type.")
|
||||
if not issubclass(obj, LogitsProcessor):
|
||||
raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor")
|
||||
classes.append(obj)
|
||||
|
||||
return classes
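# Illustrative sketch (not part of the upstream file): FQCN strings and
# already-imported classes can be mixed in the same list. "my_plugin" and
# "CustomLogitProc" are hypothetical names.
#
#   classes = _load_logitsprocs_by_fqcns(
#       [
#           "my_plugin.processors:CustomLogitProc",  # loaded from its FQCN
#           MinTokensLogitsProcessor,  # already-loaded type, validated only
#       ]
#   )
#   assert all(issubclass(c, LogitsProcessor) for c in classes)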
|
||||
|
||||
|
||||
def _load_custom_logitsprocs(
|
||||
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
|
||||
) -> list[type[LogitsProcessor]]:
|
||||
"""Load all custom logits processors.
|
||||
|
||||
* First load all installed logitproc plugins
|
||||
* Second load custom logitsprocs pass by the user at initialization time
|
||||
|
||||
Args:
|
||||
logits_processors: potentially mixed list of logitproc types and
|
||||
logitproc type fully-qualified names (FQCNs)
|
||||
which need to be loaded
|
||||
|
||||
Returns:
|
||||
A list of all loaded logitproc types
|
||||
"""
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_tpu():
|
||||
# No logitsprocs specified by caller
|
||||
# TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs
|
||||
return []
|
||||
|
||||
return _load_logitsprocs_plugins() + _load_logitsprocs_by_fqcns(logits_processors)
|
||||
|
||||
|
||||
def build_logitsprocs(
|
||||
vllm_config: "VllmConfig",
|
||||
device: torch.device,
|
||||
is_pin_memory: bool,
|
||||
is_pooling_model: bool,
|
||||
custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (),
|
||||
) -> LogitsProcessors:
|
||||
if is_pooling_model:
|
||||
if custom_logitsprocs:
|
||||
raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
|
||||
logger.debug(
|
||||
"Skipping logits processor loading because pooling models"
|
||||
" do not support logits processors."
|
||||
)
|
||||
return LogitsProcessors()
|
||||
|
||||
# Check if speculative decoding is enabled.
|
||||
if vllm_config.speculative_config:
|
||||
if custom_logitsprocs:
|
||||
raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
|
||||
logger.warning(
|
||||
"min_p, logit_bias, and min_tokens parameters won't currently work "
|
||||
"with speculative decoding enabled."
|
||||
)
|
||||
return LogitsProcessors()
|
||||
|
||||
custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
|
||||
return LogitsProcessors(
|
||||
ctor(vllm_config, device, is_pin_memory)
|
||||
for ctor in itertools.chain(
|
||||
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
|
||||
|
||||
|
||||
def validate_logits_processors_parameters(
|
||||
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
|
||||
sampling_params: SamplingParams,
|
||||
):
|
||||
logits_processors = (
|
||||
tuple(logits_processors) if logits_processors is not None else None
|
||||
)
|
||||
for logits_procs in cached_load_custom_logitsprocs(logits_processors):
|
||||
logits_procs.validate_params(sampling_params)
|
||||
|
||||
|
||||
class AdapterLogitsProcessor(LogitsProcessor):
|
||||
"""Wrapper for per-request logits processors
|
||||
|
||||
To wrap a specific per-request logits processor,
|
||||
* Subclass `AdapterLogitsProcessor`
|
||||
* Implement `self.is_argmax_invariant()` base-class method
|
||||
* Implement `self.new_req_logits_processor(params)`
|
||||
|
||||
`self.__init__(vllm_config, device, is_pin_memory)` does not need to be
|
||||
overridden in general. However, to implement custom constructor behavior -
|
||||
especially any logic which operates on or stores `vllm_config`, `device`,
|
||||
or `is_pin_memory` - `self.__init__(vllm_config, device, is_pin_memory)`
|
||||
must be overridden and the override must call
|
||||
`super().__init__(vllm_config, device, is_pin_memory)`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
"""Subclass must invoke
|
||||
`super().__init__(vllm_config, device, is_pin_memory)`.
|
||||
|
||||
Subclass constructor may find it useful to utilize the `vllm_config`,
|
||||
`device` and `is_pin_memory` argument. However regardless of whether
|
||||
these arguments are used, the vLLM logits processor interface requires
|
||||
all three arguments to be present.
|
||||
"""
|
||||
|
||||
# Map req index -> logits processor state
|
||||
#
|
||||
# State representation is a partial[Tensor] comprising a request-level
|
||||
# logits processor with the output token ids argument and (if required)
|
||||
# the prompt token ids argument pre-populated
|
||||
#
|
||||
# Note that the partial carries a *reference* to output token ids, and
|
||||
# will thus always operate on the list as it is currently, not as it
|
||||
# was when the partial was created.
|
||||
self.req_info: dict[int, partial[torch.Tensor]] = {}
|
||||
|
||||
@abstractmethod
|
||||
def new_req_logits_processor(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> RequestLogitsProcessor | None:
|
||||
"""Consume request info; return a per-request logits processor.
|
||||
|
||||
Return None if logits processor does not need to be applied to request
|
||||
|
||||
Args:
|
||||
params: request sampling params
|
||||
|
||||
Returns:
|
||||
None if logits processor should not be applied to request; otherwise
|
||||
returns a `RequestLogitsProcessor` instance
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _new_state(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
prompt_ids: list[int] | None,
|
||||
output_ids: list[int],
|
||||
) -> partial[torch.Tensor] | None:
|
||||
"""Return state representation for new request
|
||||
|
||||
Returns None if logits processor is not applicable to request
|
||||
|
||||
Args:
|
||||
params: request sampling params
|
||||
prompt_ids: request prompt token ids
|
||||
output_ids: decoded tokens so far for this request
|
||||
|
||||
Returns:
|
||||
logits processor partial[Tensor] or None
|
||||
|
||||
"""
|
||||
if req_lp := self.new_req_logits_processor(params):
|
||||
args = (
|
||||
[prompt_ids, output_ids]
|
||||
if (len(inspect.signature(req_lp).parameters) == 3)
|
||||
else [output_ids]
|
||||
)
|
||||
return partial(req_lp, *args) # type: ignore[misc]
|
||||
return None
|
||||
|
||||
def update_state(self, batch_update: BatchUpdate | None):
|
||||
process_dict_updates(
|
||||
self.req_info,
|
||||
batch_update,
|
||||
self._new_state,
|
||||
)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.req_info:
|
||||
# Apply per-request logits processors to corresponding rows of
|
||||
# logits tensor
|
||||
for req_idx, req_lp in self.req_info.items():
|
||||
req_logits = logits[req_idx]
|
||||
new_logits = req_lp(req_logits)
|
||||
if new_logits is not req_logits:
|
||||
# Modify logits tensor row in-place if necessary
|
||||
logits[req_idx] = new_logits
|
||||
return logits
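# Illustrative sketch (not part of the upstream file): a minimal
# AdapterLogitsProcessor subclass wrapping a hypothetical per-request logits
# processor. The "masked_token_id" key read from SamplingParams.extra_args is
# an assumption made for this example only.
class _ExampleMaskTokenAdapter(AdapterLogitsProcessor):
    """Mask a single token id supplied per request via extra_args."""

    def is_argmax_invariant(self) -> bool:
        # Masking a token can change the argmax outcome.
        return False

    def new_req_logits_processor(
        self, params: SamplingParams
    ) -> RequestLogitsProcessor | None:
        token_id = (params.extra_args or {}).get("masked_token_id")
        if token_id is None:
            # No per-request processor needed for this request.
            return None

        def _mask(output_ids: list[int], logits: torch.Tensor) -> torch.Tensor:
            logits[token_id] = -float("inf")
            return logits

        return _mask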
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LogitsProcessor",
|
||||
"LogitBiasLogitsProcessor",
|
||||
"MinPLogitsProcessor",
|
||||
"MinTokensLogitsProcessor",
|
||||
"BatchUpdate",
|
||||
"BatchUpdateBuilder",
|
||||
"MoveDirectionality",
|
||||
"LogitsProcessors",
|
||||
"build_logitsprocs",
|
||||
"STR_POOLING_REJECTS_LOGITSPROCS",
|
||||
"LOGITSPROCS_GROUP",
|
||||
"AdapterLogitsProcessor",
|
||||
]
|
||||
278
vllm/v1/sample/logits_processor/builtin.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.v1.sample.logits_processor.interface import (
|
||||
BatchUpdate,
|
||||
LogitsProcessor,
|
||||
MoveDirectionality,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MinPLogitsProcessor(LogitsProcessor):
|
||||
def __init__(
|
||||
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
self.min_p_count: int = 0
|
||||
|
||||
self.min_p_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory
|
||||
)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
|
||||
self.use_double_tensor = torch.device(device).type != "cpu"
|
||||
|
||||
if self.use_double_tensor:
|
||||
# Pre-allocated device tensor
|
||||
self.min_p_device: torch.Tensor = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float32, device=device
|
||||
)
|
||||
else:
|
||||
self.min_p_device = self.min_p_cpu_tensor
|
||||
# Current slice of the device tensor
|
||||
self.min_p: torch.Tensor = self.min_p_device[:0]
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""Min-p never impacts greedy sampling"""
|
||||
return True
|
||||
|
||||
def get_min_p_by_index(self, index: int) -> float:
|
||||
return float(self.min_p_cpu[index])
|
||||
|
||||
def update_state(self, batch_update: BatchUpdate | None):
|
||||
if not batch_update:
|
||||
return
|
||||
|
||||
needs_update = False
|
||||
# Process added requests.
|
||||
for index, params, _, _ in batch_update.added:
|
||||
min_p = params.min_p
|
||||
min_p_before = self.min_p_cpu[index]
|
||||
if min_p_before != min_p:
|
||||
needs_update = True
|
||||
self.min_p_cpu[index] = min_p
|
||||
if min_p and not min_p_before:
|
||||
self.min_p_count += 1
|
||||
elif not min_p and min_p_before:
|
||||
self.min_p_count -= 1
|
||||
|
||||
if self.min_p_count:
|
||||
# Process removed requests.
|
||||
if batch_update.removed:
|
||||
needs_update = True
|
||||
for index in batch_update.removed:
|
||||
if self.min_p_cpu[index]:
|
||||
self.min_p_cpu[index] = 0
|
||||
self.min_p_count -= 1
|
||||
|
||||
# Process moved requests, unidirectional (a->b) and swap (a<->b).
|
||||
for adx, bdx, direct in batch_update.moved:
|
||||
min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
|
||||
if min_p_a != min_p_b:
|
||||
needs_update = True
|
||||
self.min_p_cpu[bdx] = min_p_a
|
||||
if direct == MoveDirectionality.SWAP:
|
||||
self.min_p_cpu[adx] = min_p_b
|
||||
if direct == MoveDirectionality.UNIDIRECTIONAL:
|
||||
if min_p_a:
|
||||
self.min_p_cpu[adx] = 0
|
||||
if min_p_b:
|
||||
self.min_p_count -= 1
|
||||
|
||||
# Update tensors if needed.
|
||||
size = batch_update.batch_size
|
||||
if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
|
||||
self.min_p = self.min_p_device[:size]
|
||||
if self.use_double_tensor:
|
||||
self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
|
||||
self.min_p.unsqueeze_(1)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if not self.min_p_count:
|
||||
return logits
|
||||
|
||||
# Convert logits to probability distribution
|
||||
probability_values = torch.nn.functional.softmax(logits, dim=-1)
|
||||
# Calculate maximum probabilities per sequence
|
||||
max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
|
||||
# Adjust min_p
|
||||
adjusted_min_p = max_probabilities.mul_(self.min_p)
|
||||
# Identify valid tokens using threshold comparison
|
||||
invalid_token_mask = probability_values < adjusted_min_p
|
||||
# Apply mask using boolean indexing
|
||||
logits.masked_fill_(invalid_token_mask, -float("inf"))
|
||||
return logits
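# Illustrative sketch (not part of the upstream file): the min-p rule used in
# MinPLogitsProcessor.apply(), shown for a single toy logits row. Tokens whose
# probability falls below min_p * max_probability are masked to -inf.
def _example_min_p_mask(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    threshold = probs.max() * min_p
    return logits.masked_fill(probs < threshold, -float("inf"))


# For example, _example_min_p_mask(torch.tensor([2.0, 1.0, -3.0]), 0.2) keeps
# the first two tokens and masks the third, whose probability (~0.005) is far
# below 0.2 times the maximum probability (~0.73).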
|
||||
|
||||
|
||||
class LogitBiasLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self, _, device: torch.device, is_pin_memory: bool):
|
||||
self.device = device
|
||||
self.pin_memory = is_pin_memory
|
||||
self.biases: dict[int, dict[int, float]] = {}
|
||||
|
||||
self.bias_tensor: torch.Tensor = torch.tensor(())
|
||||
self.logits_slice = (
|
||||
self._device_tensor([], torch.int32),
|
||||
self._device_tensor([], torch.int32),
|
||||
)
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""Logit bias can rebalance token probabilities and change the
|
||||
outcome of argmax in greedy sampling."""
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: BatchUpdate | None):
|
||||
needs_update = process_dict_updates(
|
||||
self.biases, batch_update, lambda params, _, __: params.logit_bias or None
|
||||
)
|
||||
|
||||
# Update tensors if needed.
|
||||
if needs_update:
|
||||
reqs: list[int] = []
|
||||
tok_ids: list[int] = []
|
||||
biases: list[float] = []
|
||||
for req, lb in self.biases.items():
|
||||
reqs.extend([req] * len(lb))
|
||||
tok_ids.extend(lb.keys())
|
||||
biases.extend(lb.values())
|
||||
|
||||
self.bias_tensor = self._device_tensor(biases, torch.float32)
|
||||
self.logits_slice = (
|
||||
self._device_tensor(reqs, torch.int32),
|
||||
self._device_tensor(tok_ids, torch.int32),
|
||||
)
|
||||
|
||||
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch.tensor(
|
||||
data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
|
||||
).to(device=self.device, non_blocking=True)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.biases:
|
||||
logits[self.logits_slice] += self.bias_tensor
|
||||
return logits
|
||||
|
||||
|
||||
class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
def __init__(
|
||||
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
# index -> (min_toks, output_token_ids, stop_token_ids)
|
||||
self.device = device
|
||||
self.pin_memory = is_pin_memory
|
||||
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
|
||||
|
||||
        # (req_idx_tensor, stop_tok_id_tensor)
|
||||
self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (
|
||||
self._device_tensor([], torch.int32),
|
||||
self._device_tensor([], torch.int32),
|
||||
)
|
||||
|
||||
self.neg_inf_tensor = torch.tensor(
|
||||
-float("inf"), dtype=torch.float32, device=self.device
|
||||
)
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""By censoring stop tokens, min-tokens can change the outcome
|
||||
of the argmax operation in greedy sampling."""
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def add_request(
|
||||
params: SamplingParams, _: list[int] | None, output_tok_ids: list[int]
|
||||
) -> tuple[int, Sequence[int], set[int]] | None:
|
||||
min_tokens = params.min_tokens
|
||||
if not min_tokens or len(output_tok_ids) >= min_tokens:
|
||||
return None
|
||||
return min_tokens, output_tok_ids, params.all_stop_token_ids
|
||||
|
||||
def update_state(self, batch_update: BatchUpdate | None):
|
||||
needs_update = process_dict_updates(
|
||||
self.min_toks, batch_update, self.add_request
|
||||
)
|
||||
if self.min_toks:
|
||||
# Check for any requests that have attained their min tokens.
|
||||
to_remove = tuple(
|
||||
index
|
||||
for index, (min_toks, out_tok_ids, _) in self.min_toks.items()
|
||||
if len(out_tok_ids) >= min_toks
|
||||
)
|
||||
if to_remove:
|
||||
needs_update = True
|
||||
for index in to_remove:
|
||||
del self.min_toks[index]
|
||||
|
||||
# Update tensors if needed.
|
||||
if needs_update:
|
||||
reqs: list[int] = []
|
||||
tok_ids: list[int] = []
|
||||
for req, (_, _, stop_tok_ids) in self.min_toks.items():
|
||||
reqs.extend([req] * len(stop_tok_ids))
|
||||
tok_ids.extend(stop_tok_ids)
|
||||
|
||||
self.logits_slice = (
|
||||
self._device_tensor(reqs, torch.int32),
|
||||
self._device_tensor(tok_ids, torch.int32),
|
||||
)
|
||||
|
||||
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch.tensor(
|
||||
data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
|
||||
).to(device=self.device, non_blocking=True)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.min_toks:
|
||||
            # Inhibit stop tokens for requests which have not reached min length
|
||||
logits.index_put_(self.logits_slice, self.neg_inf_tensor)
|
||||
return logits
|
||||
|
||||
|
||||
def process_dict_updates(
|
||||
req_entries: dict[int, T],
|
||||
batch_update: BatchUpdate | None,
|
||||
new_state: Callable[[SamplingParams, list[int] | None, list[int]], T | None],
|
||||
) -> bool:
|
||||
"""Utility function to update dict state for sparse LogitsProcessors."""
|
||||
|
||||
if not batch_update:
|
||||
# Nothing to do.
|
||||
return False
|
||||
|
||||
updated = False
|
||||
for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
|
||||
if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None:
|
||||
req_entries[index] = state
|
||||
updated = True
|
||||
elif req_entries.pop(index, None) is not None:
|
||||
updated = True
|
||||
|
||||
if req_entries:
|
||||
# Process removed requests.
|
||||
for index in batch_update.removed:
|
||||
if req_entries.pop(index, None):
|
||||
updated = True
|
||||
|
||||
# Process moved requests, unidirectional (a->b) and
|
||||
# swapped (a<->b)
|
||||
for a_index, b_index, direct in batch_update.moved:
|
||||
a_entry = req_entries.pop(a_index, None)
|
||||
b_entry = req_entries.pop(b_index, None)
|
||||
if a_entry is not None:
|
||||
req_entries[b_index] = a_entry
|
||||
updated = True
|
||||
if b_entry is not None:
|
||||
updated = True
|
||||
if direct == MoveDirectionality.SWAP:
|
||||
req_entries[a_index] = b_entry
|
||||
|
||||
return updated
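# Illustrative sketch (not part of the upstream file): process_dict_updates
# keeps a sparse per-request dict in sync with batch changes. For example,
# with a new_state callback that records min_tokens when it is nonzero:
#
#   entries: dict[int, int] = {2: 5}
#   update = BatchUpdate(
#       batch_size=3,
#       removed=[],
#       added=[(0, SamplingParams(min_tokens=4), None, [])],
#       moved=[(2, 1, MoveDirectionality.UNIDIRECTIONAL)],
#   )
#   process_dict_updates(
#       entries, update, lambda params, _, __: params.min_tokens or None
#   )
#   # entries is now {0: 4, 1: 5}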
|
||||
106
vllm/v1/sample/logits_processor/interface.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
class MoveDirectionality(Enum):
|
||||
# One-way i1->i2 req move within batch
|
||||
UNIDIRECTIONAL = auto()
|
||||
# Two-way i1<->i2 req swap within batch
|
||||
SWAP = auto()
|
||||
|
||||
|
||||
# Batch indices of any removed requests.
|
||||
RemovedRequest = int
|
||||
|
||||
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
|
||||
# requests added to the batch.
|
||||
AddedRequest = tuple[int, SamplingParams, list[int] | None, list[int]]
|
||||
|
||||
# (index 1, index 2, directionality) tuples representing
|
||||
# one-way moves or two-way swaps of requests in batch
|
||||
MovedRequest = tuple[int, int, MoveDirectionality]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BatchUpdate:
|
||||
"""Persistent batch state change info for logitsprocs"""
|
||||
|
||||
batch_size: int # Current num reqs in batch
|
||||
|
||||
# Metadata for requests added to, removed from, and moved
|
||||
# within the persistent batch.
|
||||
#
|
||||
# Key assumption: the `output_tok_ids` list (which is an element of each
|
||||
# tuple in `added`) is a reference to the request's running output tokens
|
||||
# list; via this reference, the logits processors always see the latest
|
||||
# list of generated output tokens.
|
||||
#
|
||||
# NOTE:
|
||||
# * Added or moved requests may replace existing requests with the same
|
||||
# index.
|
||||
# * Operations should be processed in the following order:
|
||||
# - removed, added, moved
|
||||
removed: Sequence[RemovedRequest]
|
||||
added: Sequence[AddedRequest]
|
||||
moved: Sequence[MovedRequest]
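# Illustrative sketch (not part of the upstream file): a BatchUpdate describing
# one step in which the request at index 0 was removed, a new request was added
# in its place, and the request at index 2 was moved down into slot 1.
#
#   update = BatchUpdate(
#       batch_size=2,
#       removed=[0],
#       added=[(0, SamplingParams(), [1, 2, 3], [])],
#       moved=[(2, 1, MoveDirectionality.UNIDIRECTIONAL)],
#   )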
|
||||
|
||||
|
||||
class LogitsProcessor(ABC):
|
||||
@classmethod
|
||||
def validate_params(cls, sampling_params: SamplingParams):
|
||||
"""Validate sampling params for this logits processor.
|
||||
|
||||
Raise ValueError for invalid ones.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply LogitsProcessor to batch logits tensor.
|
||||
|
||||
The updated tensor must be returned but may be
|
||||
modified in-place.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""True if logits processor has no impact on the
|
||||
argmax computation in greedy sampling.
|
||||
NOTE: may or may not have the same value for all
|
||||
instances of a given LogitsProcessor subclass,
|
||||
depending on subclass implementation.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_state(
|
||||
self,
|
||||
batch_update: Optional["BatchUpdate"],
|
||||
) -> None:
|
||||
"""Called when there are new output tokens, prior
|
||||
to each forward pass.
|
||||
|
||||
Args:
|
||||
batch_update: Non-None iff there have been changes
|
||||
to the batch makeup.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
165
vllm/v1/sample/logits_processor/state.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterator
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.v1.sample.logits_processor.interface import (
|
||||
AddedRequest,
|
||||
BatchUpdate,
|
||||
MovedRequest,
|
||||
RemovedRequest,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
||||
|
||||
|
||||
class BatchUpdateBuilder:
|
||||
"""Helps track persistent batch state changes and build
|
||||
a batch update data structure for logitsprocs
|
||||
Assumptions:
|
||||
* All information about requests removed from persistent batch
|
||||
during a step is aggregated in self._removed through calls to
|
||||
self.removed_append() at the beginning of a step. This must happen
|
||||
before the first time that self.removed, self.pop_removed()
|
||||
or self.peek_removed() are invoked in a given step
|
||||
* After the first time that self.removed, self.pop_removed()
|
||||
or self.peek_removed() are read in a step, no new removals
|
||||
are registered using self.removed_append()
|
||||
* Elements of self._removed are never directly modified, added or
|
||||
removed (i.e. modification is only via self.removed_append() and
|
||||
self.pop_removed())
|
||||
Guarantees under above assumptions:
|
||||
* self.removed is always sorted in descending order
|
||||
* self.pop_removed() and self.peek_removed() both return
|
||||
the lowest removed request index in the current step
|
||||
"""
|
||||
|
||||
_removed: list[RemovedRequest]
|
||||
_is_removed_sorted: bool
|
||||
added: list[AddedRequest]
|
||||
moved: list[MovedRequest]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
removed: list[RemovedRequest] | None = None,
|
||||
added: list[AddedRequest] | None = None,
|
||||
moved: list[MovedRequest] | None = None,
|
||||
) -> None:
|
||||
self._removed = removed or []
|
||||
self.added = added or []
|
||||
self.moved = moved or []
|
||||
self._is_removed_sorted = False
|
||||
|
||||
# Used to track changes in the pooling case
|
||||
# where we don't populate the added list.
|
||||
self.batch_changed = False
|
||||
|
||||
def _ensure_removed_sorted(self) -> None:
|
||||
"""Sort removed request indices in
|
||||
descending order.
|
||||
Idempotent after first call in a
|
||||
given step, until reset.
|
||||
"""
|
||||
if not self._is_removed_sorted:
|
||||
self._removed.sort(reverse=True)
|
||||
self._is_removed_sorted = True
|
||||
|
||||
@property
|
||||
def removed(self) -> list[RemovedRequest]:
|
||||
"""Removed request indices sorted in
|
||||
descending order"""
|
||||
self._ensure_removed_sorted()
|
||||
return self._removed
|
||||
|
||||
def removed_append(self, index: int) -> None:
|
||||
"""Register the removal of a request from the persistent batch.
|
||||
|
||||
Must not be called after the first time self.removed,
|
||||
self.pop_removed() or self.peek_removed() are invoked.
|
||||
|
||||
Args:
|
||||
index: request index
|
||||
"""
|
||||
if self._is_removed_sorted:
|
||||
raise RuntimeError(
|
||||
"Cannot register new removed request after self.removed has been read."
|
||||
)
|
||||
self._removed.append(index)
|
||||
self.batch_changed = True
|
||||
|
||||
def has_removed(self) -> bool:
|
||||
return bool(self._removed)
|
||||
|
||||
def peek_removed(self) -> int | None:
|
||||
"""Return lowest removed request index"""
|
||||
if self.has_removed():
|
||||
self._ensure_removed_sorted()
|
||||
return self._removed[-1]
|
||||
return None
|
||||
|
||||
def pop_removed(self) -> int | None:
|
||||
"""Pop lowest removed request index"""
|
||||
if self.has_removed():
|
||||
self._ensure_removed_sorted()
|
||||
return self._removed.pop()
|
||||
return None
|
||||
|
||||
def reset(self) -> bool:
|
||||
"""Returns True if there were any changes to the batch."""
|
||||
self._is_removed_sorted = False
|
||||
self._removed.clear()
|
||||
self.added.clear()
|
||||
self.moved.clear()
|
||||
batch_changed = self.batch_changed
|
||||
self.batch_changed = False
|
||||
return batch_changed
|
||||
|
||||
def get_and_reset(self, batch_size: int) -> BatchUpdate | None:
|
||||
"""Generate a logitsprocs batch update data structure and reset
|
||||
internal batch update builder state.
|
||||
|
||||
Args:
|
||||
batch_size: current persistent batch size
|
||||
|
||||
Returns:
|
||||
Frozen logitsprocs batch update instance; `None` if no updates
|
||||
"""
|
||||
# Reset removal-sorting logic
|
||||
self._is_removed_sorted = False
|
||||
self.batch_changed = False
|
||||
if not any((self._removed, self.moved, self.added)):
|
||||
# No update; short-circuit
|
||||
return None
|
||||
# Build batch state update
|
||||
batch_update = BatchUpdate(
|
||||
batch_size=batch_size,
|
||||
removed=self._removed,
|
||||
moved=self.moved,
|
||||
added=self.added,
|
||||
)
|
||||
self._removed = []
|
||||
self.moved = []
|
||||
self.added = []
|
||||
return batch_update
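# Illustrative sketch (not part of the upstream file): typical per-step usage
# of BatchUpdateBuilder by the persistent batch. `params` stands in for the
# new request's SamplingParams and `logitsprocs` for a LogitsProcessors
# instance; both are assumptions for this example.
#
#   builder = BatchUpdateBuilder()
#   builder.removed_append(3)                  # request at index 3 finished
#   builder.added.append((3, params, [7, 8], []))
#   batch_update = builder.get_and_reset(batch_size=4)
#   for proc in logitsprocs.all:
#       proc.update_state(batch_update)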
|
||||
|
||||
|
||||
class LogitsProcessors:
|
||||
"""Encapsulates initialized logitsproc objects."""
|
||||
|
||||
def __init__(self, logitsprocs: Iterator["LogitsProcessor"] | None = None) -> None:
|
||||
self.argmax_invariant: list[LogitsProcessor] = []
|
||||
self.non_argmax_invariant: list[LogitsProcessor] = []
|
||||
if logitsprocs:
|
||||
for logitproc in logitsprocs:
|
||||
(
|
||||
self.argmax_invariant
|
||||
if logitproc.is_argmax_invariant()
|
||||
else self.non_argmax_invariant
|
||||
).append(logitproc)
|
||||
|
||||
@property
|
||||
def all(self) -> Iterator["LogitsProcessor"]:
|
||||
"""Iterator over all logits processors."""
|
||||
return chain(self.argmax_invariant, self.non_argmax_invariant)
|
||||
44
vllm/v1/sample/metadata.py
Normal file
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

import torch

from vllm.v1.sample.logits_processor import LogitsProcessors


@dataclass
class SamplingMetadata:
    temperature: torch.Tensor | None
    all_greedy: bool
    all_random: bool

    top_p: torch.Tensor | None
    top_k: torch.Tensor | None

    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: int | None

    no_penalties: bool
    prompt_token_ids: torch.Tensor | None
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    output_token_ids: list[list[int]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: torch.Tensor | None

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessors

    # Speculative token ids
    spec_token_ids: list[list[int]] | None = None
0
vllm/v1/sample/ops/__init__.py
Normal file
52
vllm/v1/sample/ops/bad_words.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
_SMALLEST_LOGIT = float("-inf")
|
||||
|
||||
|
||||
def _apply_bad_words_single_batch(
|
||||
logits: torch.Tensor,
|
||||
bad_words_token_ids: list[list[int]],
|
||||
past_tokens_ids: list[int],
|
||||
) -> None:
|
||||
for bad_word_ids in bad_words_token_ids:
|
||||
if len(bad_word_ids) > len(past_tokens_ids) + 1:
|
||||
continue
|
||||
|
||||
prefix_length = len(bad_word_ids) - 1
|
||||
last_token_id = bad_word_ids[-1]
|
||||
actual_prefix = past_tokens_ids[-prefix_length:] if prefix_length > 0 else []
|
||||
expected_prefix = bad_word_ids[:prefix_length]
|
||||
|
||||
assert len(actual_prefix) == len(expected_prefix)
|
||||
|
||||
if actual_prefix == expected_prefix:
|
||||
logits[last_token_id] = _SMALLEST_LOGIT
|
||||
|
||||
|
||||
def apply_bad_words(
|
||||
logits: torch.Tensor,
|
||||
bad_words_token_ids: dict[int, list[list[int]]],
|
||||
past_tokens_ids: list[list[int]],
|
||||
) -> None:
|
||||
for i, bad_words_ids in bad_words_token_ids.items():
|
||||
_apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i])
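# Illustrative sketch (not part of the upstream file): with the bad word
# [5, 6, 7], the final token id 7 is only banned when the generated tokens
# already end with the prefix [5, 6].
#
#   logits = torch.zeros(3, 10)
#   apply_bad_words(
#       logits,
#       {0: [[5, 6, 7]], 2: [[9]]},
#       past_tokens_ids=[[1, 5, 6], [5, 6], [4]],
#   )
#   # logits[0, 7] == -inf (prefix [5, 6] matches), logits[2, 9] == -inf
#   # (single-token bad word), and row 1 is untouched since it has no entry
#   # in bad_words_token_ids.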
|
||||
|
||||
|
||||
def apply_bad_words_with_drafts(
|
||||
logits: torch.Tensor,
|
||||
bad_words_token_ids: dict[int, list[list[int]]],
|
||||
past_tokens_ids: list[list[int]],
|
||||
num_draft_tokens: list[int],
|
||||
) -> None:
|
||||
start_idx = 0
|
||||
for i, bad_words_ids in bad_words_token_ids.items():
|
||||
for draft_idx in range(num_draft_tokens[i]):
|
||||
_apply_bad_words_single_batch(
|
||||
logits[start_idx + draft_idx],
|
||||
bad_words_ids,
|
||||
past_tokens_ids[start_idx + draft_idx],
|
||||
)
|
||||
start_idx += num_draft_tokens[i]
|
||||
25
vllm/v1/sample/ops/logprobs.py
Normal file
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Some utilities for logprobs, including logits."""

import torch

from vllm.platforms import current_platform


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """
    Counts elements in each row of x that are greater than or equal to the
    corresponding value in values. torch.compile is used to generate an
    optimized kernel for this function; otherwise, it would create additional
    copies of the input tensors and cause memory issues.

    Args:
        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).

    Returns:
        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
    """
    return (x >= values).sum(-1)
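# Illustrative sketch (not part of the upstream file): the count is the
# 1-based rank of each chosen value within its row, which is how this helper
# is typically used to compute logprob ranks.
#
#   logprobs = torch.tensor([[-0.1, -2.0, -3.0], [-1.0, -0.5, -4.0]])
#   chosen = torch.tensor([[-2.0], [-1.0]])
#   batched_count_greater_than(logprobs, chosen)  # -> tensor([2, 2])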
|
||||
57
vllm/v1/sample/ops/penalties.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.utils import apply_penalties
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import make_tensor_with_pad
|
||||
|
||||
|
||||
def apply_all_penalties(
|
||||
logits: torch.Tensor,
|
||||
prompt_token_ids: torch.Tensor,
|
||||
presence_penalties: torch.Tensor,
|
||||
frequency_penalties: torch.Tensor,
|
||||
repetition_penalties: torch.Tensor,
|
||||
output_token_ids: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Applies presence, frequency and repetition penalties to the logits.
|
||||
"""
|
||||
_, vocab_size = logits.shape
|
||||
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device)
|
||||
|
||||
# In the async scheduling case, rows that won't have penalties applied may contain
|
||||
# -1 placeholder token ids. We must replace these with valid token ids so that the
|
||||
# scatter done in apply_penalties is valid.
|
||||
# NOTE(nick): The penalties implementation is currently quite inefficient and
|
||||
# will be reworked anyhow.
|
||||
output_tokens_t.masked_fill_(output_tokens_t == -1, vocab_size)
|
||||
|
||||
return apply_penalties(
|
||||
logits,
|
||||
prompt_token_ids,
|
||||
output_tokens_t,
|
||||
presence_penalties,
|
||||
frequency_penalties,
|
||||
repetition_penalties,
|
||||
)
|
||||
|
||||
|
||||
def _convert_to_tensors(
|
||||
output_token_ids: list[list[int]], vocab_size: int, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert the different list data structures to tensors.
|
||||
"""
|
||||
output_tokens_tensor = make_tensor_with_pad(
|
||||
output_token_ids,
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
pad=vocab_size,
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
pin_memory=is_pin_memory_available(),
|
||||
)
|
||||
return output_tokens_tensor.to(device, non_blocking=True)
|
||||
384
vllm/v1/sample/ops/topk_topp_sampler.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TopKTopPSampler(nn.Module):
|
||||
"""
|
||||
Module that performs optional top-k and top-p filtering followed by
|
||||
weighted random sampling of logits.
|
||||
|
||||
Implementations may update the logits tensor in-place.
|
||||
"""
|
||||
|
||||
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
|
||||
super().__init__()
|
||||
self.logprobs_mode = logprobs_mode
|
||||
# flashinfer optimization does not apply if intermediate
|
||||
# logprobs/logits after top_k/top_p need to be returned
|
||||
if (
|
||||
logprobs_mode not in ("processed_logits", "processed_logprobs")
|
||||
and current_platform.is_cuda()
|
||||
):
|
||||
if envs.VLLM_USE_FLASHINFER_SAMPLER:
|
||||
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
assert capability is not None
|
||||
if not FlashInferBackend.supports_compute_capability(capability):
|
||||
capability_str = capability.as_version_str()
|
||||
raise RuntimeError(
|
||||
"FlashInfer does not support compute capability "
|
||||
f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1."
|
||||
)
|
||||
# Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
|
||||
logger.info_once(
|
||||
"Using FlashInfer for top-p & top-k sampling.",
|
||||
scope="global",
|
||||
)
|
||||
self.forward = self.forward_cuda
|
||||
else:
|
||||
logger.debug_once(
|
||||
"FlashInfer top-p/top-k sampling is available but disabled "
|
||||
"by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
|
||||
"after verifying accuracy for your workloads."
|
||||
)
|
||||
self.forward = self.forward_native
|
||||
|
||||
elif current_platform.is_cpu():
|
||||
arch = current_platform.get_cpu_architecture()
|
||||
# Fall back to native implementation for POWERPC and RISCV.
|
||||
# On PowerPC argmax produces incorrect output with torch.compile.
|
||||
# PR: https://github.com/vllm-project/vllm/pull/26987
|
||||
if arch in (CpuArchEnum.RISCV, CpuArchEnum.POWERPC):
|
||||
self.forward = self.forward_native
|
||||
else:
|
||||
self.forward = self.forward_cpu
|
||||
elif (
|
||||
logprobs_mode not in ("processed_logits", "processed_logprobs")
|
||||
and rocm_aiter_ops.is_enabled()
|
||||
):
|
||||
try:
|
||||
import aiter.ops.sampling # noqa: F401
|
||||
|
||||
self.aiter_ops = torch.ops.aiter
|
||||
logger.info_once(
|
||||
"Using aiter sampler on ROCm (lazy import, sampling-only)."
|
||||
)
|
||||
self.forward = self.forward_hip
|
||||
except ImportError:
|
||||
logger.warning_once(
|
||||
"aiter.ops.sampling is not available on ROCm. "
|
||||
"Falling back to forward_native implementation."
|
||||
)
|
||||
self.forward = self.forward_native
|
||||
else:
|
||||
self.forward = self.forward_native
|
||||
|
||||
self.apply_top_k_top_p = apply_top_k_top_p
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""
|
||||
PyTorch-native implementation of top-k and top-p sampling.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
logits = self.apply_top_k_top_p(logits, k, p)
|
||||
logits_to_return = None
|
||||
if self.logprobs_mode == "processed_logits":
|
||||
logits_to_return = logits
|
||||
elif self.logprobs_mode == "processed_logprobs":
|
||||
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators), logits_to_return
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""More optimized implementation for top-k and top-p sampling."""
|
||||
# We prefer `random_sample` over `flashinfer_sample` when sorting is
|
||||
# not needed. This is because `random_sample` does not require
|
||||
# CPU-GPU synchronization while `flashinfer_sample` does.
|
||||
if (k is None and p is None) or generators:
|
||||
if generators:
|
||||
logger.debug_once(
|
||||
"FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
"PyTorch-native implementation."
|
||||
)
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
|
||||
"FlashInfer does not support returning logits/logprobs"
|
||||
)
|
||||
# flashinfer sampling functions expect contiguous logits.
|
||||
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
|
||||
# because of slicing operation in logits_processor.
|
||||
return flashinfer_sample(logits.contiguous(), k, p, generators), None
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""
|
||||
PyTorch-native implementation of top-k and top-p sampling for CPU.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
logits = self.apply_top_k_top_p(logits, k, p)
|
||||
logits_to_return = None
|
||||
if self.logprobs_mode == "processed_logits":
|
||||
logits_to_return = logits
|
||||
elif self.logprobs_mode == "processed_logprobs":
|
||||
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
if len(generators) != logits.shape[0]:
|
||||
return compiled_random_sample(logits), logits_to_return
|
||||
else:
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Optimized ROCm/aiter path (same structure as forward_cuda)."""
|
||||
if (k is None and p is None) or generators:
|
||||
if generators:
|
||||
logger.warning_once(
|
||||
"aiter sampler does not support per-request generators; "
|
||||
"falling back to PyTorch-native."
|
||||
)
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
assert self.logprobs_mode not in (
|
||||
"processed_logits",
|
||||
"processed_logprobs",
|
||||
), "aiter sampler does not support returning logits/logprobs."
|
||||
return self.aiter_sample(logits, k, p, generators), None
|
||||
|
||||
def aiter_sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
"""Sample from logits using aiter ops."""
|
||||
use_top_k = k is not None
|
||||
use_top_p = p is not None
|
||||
# Joint k+p path
|
||||
if use_top_p and use_top_k:
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
|
||||
next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
|
||||
probs,
|
||||
None,
|
||||
*_to_tensor_scalar_tuple(k),
|
||||
*_to_tensor_scalar_tuple(p),
|
||||
deterministic=True,
|
||||
)
|
||||
return next_token_ids.view(-1)
|
||||
# Top-p only path
|
||||
elif use_top_p:
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
|
||||
next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
|
||||
probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
|
||||
)
|
||||
return next_token_ids.view(-1)
|
||||
# Top-k only path
|
||||
elif use_top_k:
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
|
||||
renorm_probs = self.aiter_ops.top_k_renorm_probs(
|
||||
probs, *_to_tensor_scalar_tuple(k)
|
||||
)
|
||||
return torch.multinomial(renorm_probs, num_samples=1).view(-1)
|
||||
raise RuntimeError("aiter_sample was called with no active top-k or top-p.")
|
||||
|
||||
|
||||
# Note: this is a workaround for
|
||||
# https://github.com/pytorch/pytorch/pull/151218
|
||||
@torch.compile(dynamic=True)
|
||||
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
return probs.div(q).argmax(dim=-1).view(-1)
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
"""Apply top-k and top-p masks to the logits.
|
||||
|
||||
If a top-p is used, this function will sort the logits tensor,
|
||||
which can be slow for large batches.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
if p is None:
|
||||
if k is None:
|
||||
return logits
|
||||
|
||||
# Avoid sorting vocab for top-k only case.
|
||||
return apply_top_k_only(logits, k)
|
||||
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
# Apply top-k.
|
||||
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
|
||||
# Get all the top_k values.
|
||||
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
|
||||
top_k_mask = logits_sort < top_k_mask
|
||||
logits_sort.masked_fill_(top_k_mask, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
# Apply top-p.
|
||||
probs_sort = logits_sort.softmax(dim=-1)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
|
||||
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
|
||||
# at least one
|
||||
top_p_mask[:, -1] = False
|
||||
logits_sort.masked_fill_(top_p_mask, -float("inf"))
|
||||
|
||||
    # Scatter the masked logits back to the original token order.
|
||||
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
|
||||
return logits
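# Illustrative sketch (not part of the upstream file): apply_top_k_top_p on a
# one-row toy batch, with per-request k and p tensors shaped like the ones the
# sampler passes in.
def _example_apply_top_k_top_p() -> torch.Tensor:
    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    k = torch.tensor([3])  # keep the 3 largest logits per row
    p = torch.tensor([0.9])  # then keep the smallest nucleus with mass >= 0.9
    # Result: tensor([[-inf, -inf, 3., 4.]]); only the two most likely tokens
    # survive the combined top-k and top-p masks.
    return apply_top_k_top_p(logits, k, p)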
|
||||
|
||||
|
||||
def apply_top_k_only(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k mask to the logits.
|
||||
|
||||
This implementation doesn't involve sorting the entire vocab.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
no_top_k_mask = k == logits.shape[1]
|
||||
# Set non-top-k rows to 1 so that we can gather.
|
||||
k = k.masked_fill(no_top_k_mask, 1)
|
||||
max_top_k = k.max()
|
||||
# topk.values tensor has shape [batch_size, max_top_k].
|
||||
# Convert top k to 0-based index in range [0, max_top_k).
|
||||
k_index = k.sub_(1).unsqueeze(1)
|
||||
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
|
||||
# Handle non-topk rows.
|
||||
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
|
||||
logits.masked_fill_(logits < top_k_mask, -float("inf"))
|
||||
return logits
|
||||
|
||||
|
||||
def random_sample(
|
||||
probs: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
"""Randomly sample from the probabilities.
|
||||
|
||||
We use this function instead of torch.multinomial because torch.multinomial
|
||||
causes CPU-GPU synchronization.
|
||||
"""
|
||||
q = torch.empty_like(probs)
|
||||
# NOTE(woosuk): To batch-process the requests without their own seeds,
|
||||
# which is the common case, we first assume that every request does
|
||||
# not have its own seed. Then, we overwrite the values for the requests
|
||||
# that have their own seeds.
|
||||
if len(generators) != probs.shape[0]:
|
||||
q.exponential_()
|
||||
if generators:
|
||||
# TODO(woosuk): This can be slow because we handle each request
|
||||
# one by one. Optimize this.
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
return probs.div_(q).argmax(dim=-1).view(-1)
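# Illustrative note (not part of the upstream file): dividing the probabilities
# by i.i.d. Exponential(1) noise and taking the argmax is the exponential-race
# form of the Gumbel-max trick: argmax_i p_i / E_i == argmin_i E_i / p_i, and
# E_i / p_i ~ Exponential(rate=p_i), so index i wins the race with probability
# p_i / sum_j p_j. A quick empirical check:
#
#   probs = torch.tensor([0.1, 0.3, 0.6]).expand(100_000, 3).contiguous()
#   samples = random_sample(probs, generators={})
#   print(torch.bincount(samples, minlength=3) / samples.numel())
#   # ~ tensor([0.10, 0.30, 0.60])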
|
||||
|
||||
|
||||
def flashinfer_sample(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
"""Sample from the logits using FlashInfer.
|
||||
|
||||
Statistically, this function is equivalent to the `random_sample` function.
|
||||
    However, this function is faster because it uses rejection sampling,
    which avoids sorting the logits tensor.
|
||||
|
||||
NOTE: The outputs of this function do not necessarily match the outputs of
|
||||
the `random_sample` function. It only guarantees that the outputs are
|
||||
statistically equivalent.
|
||||
|
||||
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
|
||||
does not. Call this function at the end of the forward pass to minimize
|
||||
the synchronization overhead.
|
||||
"""
|
||||
import flashinfer
|
||||
|
||||
if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
|
||||
raise ImportError(
|
||||
"FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
|
||||
)
|
||||
|
||||
assert not (k is None and p is None)
|
||||
if k is None:
|
||||
# Top-p only.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
|
||||
probs, p, deterministic=True
|
||||
)
|
||||
elif p is None:
|
||||
# Top-k only.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
|
||||
probs, k, deterministic=True
|
||||
)
|
||||
else:
|
||||
# Both top-k and top-p.
|
||||
next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
|
||||
logits, k, p, deterministic=True
|
||||
)
|
||||
|
||||
return next_token_ids.view(-1)
|
||||
|
||||
|
||||
def _to_tensor_scalar_tuple(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return (x, 0)
|
||||
else:
|
||||
return (None, x)
|
||||
805
vllm/v1/sample/rejection_sampler.py
Normal file
@@ -0,0 +1,805 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import replace
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
|
||||
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
|
||||
GREEDY_TEMPERATURE: tl.constexpr = 0
|
||||
# Maximum number of speculative draft tokens allowed per request in a single
|
||||
# step. This value is chosen to be large enough to handle typical use cases.
|
||||
MAX_SPEC_LEN = 128
|
||||
|
||||
|
||||
class RejectionSampler(nn.Module):
|
||||
"""
|
||||
The implementation strictly follows the algorithm described in
|
||||
https://arxiv.org/abs/2211.17192.
|
||||
However, we want to clarify the terminology used in the implementation:
|
||||
accepted tokens: tokens that are accepted based on the relationship
|
||||
between the "raw" draft and target probabilities.
|
||||
recovered tokens: tokens that are sampled based on the adjusted probability
|
||||
distribution, which is derived from both the draft and target
|
||||
probabilities.
|
||||
bonus tokens:
|
||||
If all proposed tokens are accepted, the bonus token is added to the
|
||||
end of the sequence. The bonus token is only sampled from the target
|
||||
probabilities. We pass in the bonus tokens instead of sampling them
|
||||
in the rejection sampler to allow for more flexibility in the
|
||||
sampling process. For example, we can use top_p, top_k sampling for
|
||||
bonus tokens, while spec decode does not support these sampling
|
||||
strategies.
|
||||
output tokens:
|
||||
Tokens are finally generated with the rejection sampler.
|
||||
output tokens = accepted tokens + recovered tokens + bonus tokens
|
||||
"""
|
||||
|
||||
def __init__(self, sampler: Sampler):
|
||||
super().__init__()
|
||||
self.sampler = sampler
|
||||
logprobs_mode = self.sampler.logprobs_mode
|
||||
self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
|
||||
self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
metadata: SpecDecodeMetadata,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: torch.Tensor | None,
|
||||
# [num_tokens + batch_size, vocab_size]
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
"""
|
||||
Args:
|
||||
metadata:
|
||||
Metadata for spec decoding.
|
||||
draft_probs (Optional[torch.Tensor]):
|
||||
Probability distribution for the draft tokens. Shape is
|
||||
[num_tokens, vocab_size]. Can be None if probabilities are
|
||||
not provided, which is the case for ngram spec decode.
|
||||
logits (torch.Tensor):
|
||||
Target model's logits (unnormalized scores over the vocabulary).
|
||||
Shape is [num_tokens + batch_size, vocab_size]. Here,
|
||||
logits from different requests are flattened into a
|
||||
single tensor because this is the shape of the output logits.
|
||||
NOTE: `logits` can be updated in place to save memory.
|
||||
sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
|
||||
Additional metadata needed for sampling, such as temperature,
|
||||
top-k/top-p parameters, or other relevant information.
|
||||
Returns:
|
||||
SamplerOutput:
|
||||
Contains the final output token IDs and their logprobs if
|
||||
requested.
|
||||
"""
|
||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||
|
||||
bonus_logits_indices = metadata.bonus_logits_indices
|
||||
target_logits_indices = metadata.target_logits_indices
|
||||
|
||||
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
||||
# creates a new tensor with separate storage from the original
|
||||
# logits tensor. This means any in-place operations on bonus_logits
|
||||
# won't affect the original logits tensor.
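# Illustrative example (editorial): b = logits[torch.tensor([0, 2])] copies
# rows 0 and 2, so b.mul_(0) leaves `logits` untouched, whereas a basic
# slice such as logits[0:2] would share storage with it.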
|
||||
assert logits is not None
|
||||
bonus_logits = logits[bonus_logits_indices]
|
||||
bonus_sampler_output = self.sampler(
|
||||
logits=bonus_logits,
|
||||
sampling_metadata=replace(
|
||||
sampling_metadata,
|
||||
max_num_logprobs=-1,
|
||||
),
|
||||
predict_bonus_token=True,
|
||||
# Override the logprobs mode to return logits because they are
|
||||
# needed later to compute the accepted token logprobs.
|
||||
logprobs_mode_override="processed_logits"
|
||||
if self.is_processed_logprobs_mode
|
||||
else "raw_logits",
|
||||
)
|
||||
bonus_token_ids = bonus_sampler_output.sampled_token_ids
|
||||
|
||||
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
||||
# separate storage from the original `logits` tensor. Therefore,
|
||||
# it is safe to update `target_logits` in place.
|
||||
raw_target_logits = logits[target_logits_indices]
|
||||
# Use float32 for the target_logits.
|
||||
raw_target_logits = raw_target_logits.to(torch.float32)
|
||||
target_logits = self.apply_logits_processors(
|
||||
raw_target_logits, sampling_metadata, metadata
|
||||
)
|
||||
# [num_tokens, vocab_size]
|
||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||
# `apply_sampling_constraints` function.
|
||||
target_logits = apply_sampling_constraints(
|
||||
target_logits,
|
||||
metadata.cu_num_draft_tokens,
|
||||
sampling_metadata,
|
||||
)
|
||||
# Compute probability distribution from target logits.
|
||||
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
output_token_ids = rejection_sample(
|
||||
metadata.draft_token_ids,
|
||||
metadata.num_draft_tokens,
|
||||
metadata.max_spec_len,
|
||||
metadata.cu_num_draft_tokens,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
sampling_metadata,
|
||||
)
|
||||
|
||||
logprobs_tensors = None
|
||||
if sampling_metadata.max_num_logprobs is not None:
|
||||
logprobs_tensors = self._get_logprobs_tensors(
|
||||
sampling_metadata.max_num_logprobs,
|
||||
metadata,
|
||||
logits,
|
||||
target_logits if self.is_processed_logprobs_mode else raw_target_logits,
|
||||
bonus_sampler_output.logprobs_tensors.logprobs,
|
||||
output_token_ids,
|
||||
)
|
||||
|
||||
return SamplerOutput(
|
||||
sampled_token_ids=output_token_ids,
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
|
||||
def _get_logprobs_tensors(
|
||||
self,
|
||||
max_num_logprobs: int,
|
||||
metadata: SpecDecodeMetadata,
|
||||
logits: torch.Tensor,
|
||||
target_logits: torch.Tensor,
|
||||
bonus_logits: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
|
||||
cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]
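# Editorial note: this shift turns cumulative sampled-token counts into
# per-request start offsets, e.g. [2, 5, 6] -> [0, 2, 5].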
|
||||
|
||||
# Collect target and bonus logits.
|
||||
bonus_logits_indices = metadata.bonus_logits_indices
|
||||
target_logits_indices = metadata.target_logits_indices
|
||||
final_logits = torch.zeros_like(logits, dtype=torch.float32)
|
||||
final_logits[target_logits_indices] = target_logits.to(torch.float32)
|
||||
final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)
|
||||
|
||||
# Compute accepted token indices.
|
||||
accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
|
||||
num_accepted_tokens = accepted_mask.sum(dim=-1)
|
||||
accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]
|
||||
accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(
|
||||
num_accepted_tokens
|
||||
)
|
||||
|
||||
# Compute logprobs for accepted tokens.
|
||||
accepted_logits = final_logits[accepted_logit_indices]
|
||||
accepted_logprobs = (
|
||||
accepted_logits
|
||||
if self.is_logits_logprobs_mode
|
||||
else self.sampler.compute_logprobs(accepted_logits)
|
||||
)
|
||||
accepted_tokens = sampled_token_ids[accepted_mask]
|
||||
return self.sampler.gather_logprobs(
|
||||
accepted_logprobs,
|
||||
max_num_logprobs,
|
||||
accepted_tokens.to(torch.int64),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_output(
|
||||
output_token_ids: torch.Tensor,
|
||||
vocab_size: int,
|
||||
discard_req_indices: Sequence[int] = (),
|
||||
return_cu_num_tokens: bool = False,
|
||||
) -> tuple[list[list[int]], list[int] | None]:
|
||||
"""Parse the output of the rejection sampler.
|
||||
Args:
|
||||
output_token_ids: The sampled token IDs in shape
|
||||
[batch_size, max_spec_len + 1]. The rejected tokens are
|
||||
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
|
||||
and will be filtered out in this function.
|
||||
vocab_size: The size of the vocabulary.
|
||||
discard_req_indices: Optional request (row) indices whose sampled tokens should be discarded.
|
||||
return_cu_num_tokens: Whether to also return cumulative token counts.
|
||||
Returns:
|
||||
A tuple: (token IDs per request with invalid tokens removed, cumulative valid-token counts if `return_cu_num_tokens` else None).
|
||||
"""
|
||||
output_token_ids_np = output_token_ids.cpu().numpy()
|
||||
# Create mask for valid tokens.
|
||||
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
|
||||
output_token_ids_np < vocab_size
|
||||
)
|
||||
cu_num_tokens = None
|
||||
if return_cu_num_tokens:
|
||||
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
|
||||
if len(discard_req_indices) > 0:
|
||||
valid_mask[discard_req_indices] = False
|
||||
outputs = [
|
||||
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
|
||||
]
|
||||
return outputs, cu_num_tokens
|
||||
|
||||
def apply_logits_processors(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
metadata: SpecDecodeMetadata,
|
||||
) -> torch.Tensor:
|
||||
has_penalties = not sampling_metadata.no_penalties
|
||||
any_penalties_or_bad_words = (
|
||||
sampling_metadata.bad_words_token_ids or has_penalties
|
||||
)
|
||||
|
||||
output_token_ids = sampling_metadata.output_token_ids
|
||||
if any_penalties_or_bad_words:
|
||||
output_token_ids = self._combine_outputs_with_spec_tokens(
|
||||
output_token_ids,
|
||||
sampling_metadata.spec_token_ids,
|
||||
)
|
||||
|
||||
# Calculate indices of target logits.
|
||||
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
|
||||
num_requests = len(sampling_metadata.output_token_ids)
|
||||
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
|
||||
original_indices = torch.arange(num_requests, device="cpu")
|
||||
repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
|
||||
repeat_indices = repeat_indices_cpu.to(
|
||||
device=logits.device, non_blocking=True
|
||||
)
|
||||
logits = self.apply_penalties(
|
||||
logits, sampling_metadata, metadata, repeat_indices, output_token_ids
|
||||
)
|
||||
|
||||
# Apply allowed token ids.
|
||||
if sampling_metadata.allowed_token_ids_mask is not None:
|
||||
token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices]
|
||||
logits.masked_fill_(token_mask, float("-inf"))
|
||||
|
||||
# Apply bad words exclusion.
|
||||
if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
|
||||
apply_bad_words_with_drafts(
|
||||
logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def apply_penalties(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
metadata: SpecDecodeMetadata,
|
||||
repeat_indices: torch.Tensor,
|
||||
output_token_ids: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
if sampling_metadata.no_penalties:
|
||||
return logits
|
||||
|
||||
assert sampling_metadata.prompt_token_ids is not None
|
||||
|
||||
prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices]
|
||||
presence_penalties = sampling_metadata.presence_penalties[repeat_indices]
|
||||
frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices]
|
||||
repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices]
|
||||
|
||||
logits = apply_all_penalties(
|
||||
logits,
|
||||
prompt_token_ids,
|
||||
presence_penalties,
|
||||
frequency_penalties,
|
||||
repetition_penalties,
|
||||
output_token_ids,
|
||||
)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _combine_outputs_with_spec_tokens(
|
||||
output_token_ids: list[list[int]],
|
||||
spec_token_ids: list[list[int]] | None = None,
|
||||
) -> list[list[int]]:
|
||||
if spec_token_ids is None:
|
||||
return output_token_ids
|
||||
|
||||
result = []
|
||||
for out, spec in zip(output_token_ids, spec_token_ids):
|
||||
if len(spec) == 0:
|
||||
continue
|
||||
result.append(out)
|
||||
for i in range(len(spec) - 1):
|
||||
result.append([*result[-1], spec[i]])
|
||||
return result
|
||||
|
||||
|
||||
def rejection_sample(
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [batch_size]
|
||||
num_draft_tokens: list[int],
|
||||
max_spec_len: int,
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: torch.Tensor | None,
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert draft_token_ids.ndim == 1
|
||||
assert draft_probs is None or draft_probs.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
assert target_probs.ndim == 2
|
||||
|
||||
batch_size = len(num_draft_tokens)
|
||||
num_tokens = draft_token_ids.shape[0]
|
||||
vocab_size = target_probs.shape[-1]
|
||||
device = target_probs.device
|
||||
assert draft_token_ids.is_contiguous()
|
||||
assert draft_probs is None or draft_probs.is_contiguous()
|
||||
assert target_probs.is_contiguous()
|
||||
assert bonus_token_ids.is_contiguous()
|
||||
assert target_probs.shape == (num_tokens, vocab_size)
|
||||
|
||||
# Create output buffer.
|
||||
output_token_ids = torch.full(
|
||||
(batch_size, max_spec_len + 1),
|
||||
PLACEHOLDER_TOKEN_ID,
|
||||
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
|
||||
device=device,
|
||||
)
|
||||
|
||||
if sampling_metadata.all_greedy:
|
||||
is_greedy = None
|
||||
else:
|
||||
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
|
||||
if not sampling_metadata.all_random:
|
||||
# Rejection sampling for greedy sampling requests.
|
||||
target_argmax = target_probs.argmax(dim=-1)
|
||||
rejection_greedy_sample_kernel[(batch_size,)](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
)
|
||||
if sampling_metadata.all_greedy:
|
||||
return output_token_ids
|
||||
|
||||
# Generate uniform probabilities for rejection sampling.
|
||||
# [num_tokens]
|
||||
uniform_probs = generate_uniform_probs(
|
||||
num_tokens,
|
||||
num_draft_tokens,
|
||||
sampling_metadata.generators,
|
||||
device,
|
||||
)
|
||||
|
||||
# Sample recovered tokens for each position.
|
||||
# [num_tokens]
|
||||
recovered_token_ids = sample_recovered_tokens(
|
||||
max_spec_len,
|
||||
num_draft_tokens,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device,
|
||||
)
|
||||
|
||||
# Rejection sampling for random sampling requests.
|
||||
rejection_random_sample_kernel[(batch_size,)](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
recovered_token_ids,
|
||||
uniform_probs,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
|
||||
def apply_sampling_constraints(
|
||||
logits: torch.Tensor, # [num_tokens, vocab_size]
|
||||
cu_num_draft_tokens: torch.Tensor, # [batch_size]
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Process logits based on sampling metadata.
|
||||
|
||||
This function applies temperature scaling to the logits,
|
||||
as well as top-k and top-p. For greedy decoding, it returns
|
||||
the original logits.
|
||||
|
||||
Args:
|
||||
logits: Input logits tensor to be processed.
|
||||
cu_num_draft_tokens: Cumulative number of draft tokens.
|
||||
sampling_metadata: Metadata containing sampling parameters such as
|
||||
temperature and whether greedy sampling is used.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Processed logits if non-greedy sampling is used,
|
||||
otherwise returns the original logits.
|
||||
"""
|
||||
assert logits.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
if sampling_metadata.all_greedy:
|
||||
return logits
|
||||
|
||||
num_tokens = logits.shape[0]
|
||||
temperature = expand_batch_to_tokens(
|
||||
sampling_metadata.temperature,
|
||||
cu_num_draft_tokens,
|
||||
num_tokens,
|
||||
replace_from=GREEDY_TEMPERATURE,
|
||||
replace_to=1,
|
||||
)
|
||||
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
|
||||
logits.div_(temperature.unsqueeze(-1))
|
||||
|
||||
# Get expanded top_k and top_p tensors.
|
||||
top_k = None
|
||||
if sampling_metadata.top_k is not None:
|
||||
top_k = expand_batch_to_tokens(
|
||||
sampling_metadata.top_k,
|
||||
cu_num_draft_tokens,
|
||||
num_tokens,
|
||||
)
|
||||
top_p = None
|
||||
if sampling_metadata.top_p is not None:
|
||||
top_p = expand_batch_to_tokens(
|
||||
sampling_metadata.top_p,
|
||||
cu_num_draft_tokens,
|
||||
num_tokens,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
|
||||
# which is slow for large vocab sizes. This may cause performance issues.
|
||||
return apply_top_k_top_p(logits, top_k, top_p)
|
||||
|
||||
|
||||
def expand_batch_to_tokens(
|
||||
x: torch.Tensor, # [batch_size]
|
||||
cu_num_tokens: torch.Tensor, # [batch_size]
|
||||
num_tokens: int,
|
||||
replace_from: int = 0,
|
||||
replace_to: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Expand [batch_size] tensor to [num_tokens] tensor based on the number of
|
||||
tokens per batch in cu_num_tokens.
|
||||
|
||||
For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
|
||||
num_tokens = 6, and expanded_x = [a, a, b, b, b, c].
|
||||
|
||||
Args:
|
||||
x: [batch_size] tensor to expand.
|
||||
cu_num_tokens: [batch_size] tensor containing the cumulative number of
|
||||
tokens per batch. Each element represents the total number of
|
||||
tokens up to and including that batch.
|
||||
num_tokens: Total number of tokens.
|
||||
replace_from: int = 0
|
||||
Value to be replaced if it is found in x.
|
||||
replace_to: int = 0
|
||||
Value to replace with when replace_from is found.
|
||||
Returns:
|
||||
expanded_x: [num_tokens] tensor.
|
||||
"""
|
||||
batch_size = x.shape[0]
|
||||
assert cu_num_tokens.shape[0] == batch_size
|
||||
expanded_x = x.new_empty(num_tokens)
|
||||
expand_kernel[(batch_size,)](
|
||||
expanded_x,
|
||||
x,
|
||||
cu_num_tokens,
|
||||
replace_from,
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
|
||||
)
|
||||
return expanded_x
|
||||
|
||||
|
||||
def generate_uniform_probs(
|
||||
num_tokens: int,
|
||||
num_draft_tokens: list[int],
|
||||
generators: dict[int, torch.Generator],
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generates a batch of uniform random samples, with optional seeding
|
||||
if available.
|
||||
|
||||
This method creates a tensor of shape `(num_tokens, )` filled
|
||||
with uniform random values in the range [0, 1). If `generators` is provided,
|
||||
the requests with their own seeds will use the provided `torch.Generator`
|
||||
for reproducibility. The samples for the other requests will be generated
|
||||
without a seed.
|
||||
|
||||
Args:
|
||||
num_tokens: int
|
||||
Total number of tokens.
|
||||
num_draft_tokens: list[int]
|
||||
Number of draft tokens per request.
|
||||
generators: dict[int, torch.Generator]
|
||||
A dictionary mapping indices in the batch to
|
||||
`torch.Generator` objects.
|
||||
device: torch.device
|
||||
The device on which to allocate the tensor.
|
||||
Returns:
|
||||
uniform_rand: torch.Tensor
|
||||
A tensor of shape `(num_tokens, )` containing uniform
|
||||
random values in the range [0, 1).
|
||||
"""
|
||||
# NOTE(woosuk): We deliberately use float64 instead of float32 here
|
||||
# because when using float32, there's a non-negligible chance that
|
||||
# uniform_prob is sampled to be exact 0.0 as reported in
|
||||
# https://github.com/pytorch/pytorch/issues/16706. Using float64
|
||||
# mitigates the issue.
|
||||
uniform_probs = torch.rand(
|
||||
(num_tokens,),
|
||||
dtype=torch.float64,
|
||||
device=device,
|
||||
)
|
||||
start_idx = 0
|
||||
for req_idx, n in enumerate(num_draft_tokens):
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if n == 0:
|
||||
continue
|
||||
end_idx = start_idx + n
|
||||
generator = generators.get(req_idx)
|
||||
if generator is not None:
|
||||
uniform_probs[start_idx:end_idx].uniform_(generator=generator)
|
||||
start_idx = end_idx
|
||||
return uniform_probs
|
||||
|
||||
|
||||
def sample_recovered_tokens(
|
||||
max_spec_len: int,
|
||||
num_draft_tokens: list[int],
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: torch.Tensor | None,
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): Create only one distribution for each request.
|
||||
batch_size = len(num_draft_tokens)
|
||||
vocab_size = target_probs.shape[-1]
|
||||
q = torch.empty(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
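# Editorial note: dividing a (possibly unnormalized) probability vector by
# i.i.d. Exponential(1) noise and taking the argmax selects index i with
# probability proportional to prob[i] (the "exponential race" variant of the
# Gumbel-max trick); sample_recovered_tokens_kernel relies on this below via
# tl.argmax(prob / q).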
|
||||
for i, generator in sampling_metadata.generators.items():
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if num_draft_tokens[i] > 0:
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
recovered_token_ids = torch.empty_like(draft_token_ids)
|
||||
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
|
||||
recovered_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
q,
|
||||
vocab_size,
|
||||
triton.next_power_of_2(vocab_size),
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
)
|
||||
return recovered_token_ids
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
def rejection_greedy_sample_kernel(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
target_argmax_ptr, # [num_tokens]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
is_greedy_ptr, # [batch_size] or None
|
||||
max_spec_len,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
|
||||
# re-compilation may happen during runtime when is_greedy_ptr is None.
|
||||
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
|
||||
if not is_greedy:
|
||||
# Early exit for non-greedy sampling requests.
|
||||
return
|
||||
|
||||
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
rejected = False
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
target_argmax_id,
|
||||
)
|
||||
if draft_token_id != target_argmax_id:
|
||||
# Reject.
|
||||
rejected = True
|
||||
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
|
||||
bonus_token_id,
|
||||
)
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
def rejection_random_sample_kernel(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
recovered_token_ids_ptr, # [num_tokens]
|
||||
uniform_probs_ptr, # [num_tokens]
|
||||
is_greedy_ptr, # [batch_size]
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
||||
if is_greedy:
|
||||
# Early exit for greedy sampling requests.
|
||||
return
|
||||
|
||||
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
rejected = False
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_prob = 1
|
||||
else:
|
||||
draft_prob = tl.load(
|
||||
draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
|
||||
)
|
||||
target_prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
|
||||
)
|
||||
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||
# NOTE(woosuk): While the draft probability should never be 0,
|
||||
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
||||
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||
# Accept.
|
||||
token_id = draft_token_id
|
||||
else:
|
||||
# Reject. Use recovered token.
|
||||
rejected = True
|
||||
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id
|
||||
)
|
||||
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
|
||||
bonus_token_id,
|
||||
)
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
|
||||
def expand_kernel(
|
||||
output_ptr, # [num_tokens]
|
||||
input_ptr, # [batch_size]
|
||||
cu_num_tokens_ptr, # [batch_size]
|
||||
replace_from,
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
if req_idx == 0: # noqa: SIM108
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
src_val = tl.load(input_ptr + req_idx)
|
||||
src_val = tl.where(src_val == replace_from, replace_to, src_val)
|
||||
offset = tl.arange(0, MAX_NUM_TOKENS)
|
||||
tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def sample_recovered_tokens_kernel(
|
||||
output_token_ids_ptr, # [num_tokens]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
q_ptr, # [batch_size, vocab_size]
|
||||
vocab_size,
|
||||
PADDED_VOCAB_SIZE: tl.constexpr,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
# Early exit for out-of-range positions.
|
||||
pos = tl.program_id(1)
|
||||
if pos >= num_draft_tokens:
|
||||
return
|
||||
|
||||
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)),
|
||||
other=0,
|
||||
)
|
||||
else:
|
||||
draft_prob = tl.load(
|
||||
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0,
|
||||
)
|
||||
target_prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0,
|
||||
)
|
||||
prob = tl.maximum(target_prob - draft_prob, 0)
|
||||
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
|
||||
# `tl.argmax` is invariant to rescaling `prob` by a positive constant.
|
||||
|
||||
q = tl.load(
|
||||
q_ptr + req_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=float("-inf"),
|
||||
)
|
||||
recovered_id = tl.argmax(prob / q, axis=-1)
|
||||
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
|
||||
319
vllm/v1/sample/sampler.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A layer that samples the next tokens from the model's outputs."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
||||
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
|
||||
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
"""
|
||||
A layer that samples the next tokens from the model's outputs
|
||||
with the following steps in order:
|
||||
|
||||
1. If logprobs are requested:
|
||||
a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
|
||||
as the final logprobs to return.
|
||||
b) If `logprobs_mode` is `raw_logits`, clone the logits
|
||||
as the final logprobs to return.
|
||||
2. Convert logits to float32.
|
||||
3. Apply allowed token ids whitelist.
|
||||
4. Apply bad words exclusion.
|
||||
5. Apply logit processors which are not argmax-invariant,
|
||||
i.e. that can impact greedy sampling.
|
||||
a) Min tokens processor
|
||||
b) Logit bias processor
|
||||
6. Apply penalties
|
||||
a) Repetition penalty
|
||||
b) Frequency penalty
|
||||
c) Presence penalty
|
||||
7. Sample the next tokens. `sample` method performs the following steps:
|
||||
a) If not `all_random`, perform greedy sampling. If `all_greedy`,
|
||||
return the greedily sampled tokens and final logprobs if requested.
|
||||
b) Apply temperature.
|
||||
c) Apply logit processors which are argmax-invariant, by default
|
||||
the min_p processor.
|
||||
d) Apply top_k and/or top_p.
|
||||
e) Sample the next tokens with the probability distribution.
|
||||
f) If `all_random` or temperature >= epsilon (1e-5), return the
|
||||
randomly sampled tokens and final logprobs if requested. Else,
|
||||
return the greedily sampled tokens and logprobs if requested.
|
||||
8. Gather the logprobs of the top `max_num_logprobs` and sampled token
|
||||
(if requested). Note that if the sampled token is within the top
|
||||
`max_num_logprobs`, the logprob will be eventually merged in
|
||||
`LogprobsProcessor` during output processing. Therefore, the
|
||||
final output may contain either `max_num_logprobs + 1` or
|
||||
`max_num_logprobs` logprobs.
|
||||
9. Return the final `SamplerOutput`.
|
||||
"""
|
||||
|
||||
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.logprobs_mode = logprobs_mode
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
predict_bonus_token: bool = False,
|
||||
logprobs_mode_override: LogprobsMode | None = None,
|
||||
) -> SamplerOutput:
|
||||
logprobs_mode = logprobs_mode_override or self.logprobs_mode
|
||||
# NOTE(woosuk): Use the original logits (before any penalties or
|
||||
# temperature scaling) for the top-k logprobs.
|
||||
# This is different from the V0 sampler, which uses the logits that
|
||||
# is used for sampling (after penalties and temperature scaling).
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
if logprobs_mode == "raw_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(logits)
|
||||
elif logprobs_mode == "raw_logits":
|
||||
if logits.dtype == torch.float32:
|
||||
raw_logprobs = logits.clone()
|
||||
else:
|
||||
raw_logprobs = logits.to(torch.float32)
|
||||
|
||||
# Use float32 for the logits.
|
||||
logits = logits.to(torch.float32)
|
||||
|
||||
logits = self.apply_logits_processors(
|
||||
logits, sampling_metadata, predict_bonus_token
|
||||
)
|
||||
# Sample the next token.
|
||||
sampled, processed_logprobs = self.sample(logits, sampling_metadata)
|
||||
if processed_logprobs is not None:
|
||||
raw_logprobs = processed_logprobs
|
||||
# Convert sampled token ids to int64 (long) type to ensure compatibility
|
||||
# with subsequent operations that may use these values as indices.
|
||||
# This conversion is necessary because FlashInfer sampling operations
|
||||
# return int32 (while PyTorch argmax and topk return int64).
|
||||
sampled = sampled.long()
|
||||
|
||||
if num_logprobs is None:
|
||||
logprobs_tensors = None
|
||||
elif num_logprobs == -1:
|
||||
# Return the full unsorted and unranked logprobs.
|
||||
logprobs_tensors = LogprobsTensors(
|
||||
torch.empty(0), raw_logprobs, torch.empty(0)
|
||||
)
|
||||
else:
|
||||
# Gather the logprobs and ranks of the topk and sampled token.
|
||||
logprobs_tensors = self.gather_logprobs(
|
||||
raw_logprobs, num_logprobs, token_ids=sampled
|
||||
)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
sampled = sampled.to(torch.int32)
|
||||
|
||||
# These are GPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.unsqueeze(-1),
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
@staticmethod
|
||||
def apply_temperature(
|
||||
logits: torch.Tensor,
|
||||
temp: torch.Tensor,
|
||||
all_random: bool,
|
||||
) -> torch.Tensor:
|
||||
# Use in-place division to avoid creating a new tensor.
|
||||
# Avoid division by zero if there are greedy requests.
|
||||
if not all_random:
|
||||
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
||||
return logits.div_(temp.unsqueeze(dim=1))
|
||||
|
||||
@staticmethod
|
||||
def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.argmax(dim=-1).view(-1)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
logprobs_mode_override: LogprobsMode | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Sample logits based on sampling metadata.
|
||||
|
||||
The various logits processing functions called in this method
|
||||
may update the logits tensor in-place.
|
||||
"""
|
||||
|
||||
logprobs_mode = logprobs_mode_override or self.logprobs_mode
|
||||
assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
|
||||
if sampling_metadata.all_random:
|
||||
greedy_sampled = None
|
||||
else:
|
||||
greedy_sampled = self.greedy_sample(logits)
|
||||
if sampling_metadata.all_greedy:
|
||||
processed_logprobs = None
|
||||
if sampling_metadata.max_num_logprobs is not None:
|
||||
if logprobs_mode == "processed_logits":
|
||||
processed_logprobs = logits
|
||||
elif logprobs_mode == "processed_logprobs":
|
||||
processed_logprobs = self.compute_logprobs(logits)
|
||||
return greedy_sampled, processed_logprobs
|
||||
|
||||
assert sampling_metadata.temperature is not None
|
||||
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(
|
||||
logits, sampling_metadata.temperature, sampling_metadata.all_random
|
||||
)
|
||||
|
||||
# Apply logits processors that only apply to random sampling
|
||||
# (argmax invariant)
|
||||
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled, processed_logprobs = self.topk_topp_sampler(
|
||||
logits,
|
||||
sampling_metadata.generators,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
if greedy_sampled is None:
|
||||
return random_sampled, processed_logprobs
|
||||
|
||||
sampled = torch.where(
|
||||
sampling_metadata.temperature < _SAMPLING_EPS,
|
||||
greedy_sampled,
|
||||
random_sampled,
|
||||
out=greedy_sampled, # Reuse tensor
|
||||
)
|
||||
return sampled, processed_logprobs
|
||||
|
||||
@staticmethod
|
||||
def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
@staticmethod
|
||||
def gather_logprobs(
|
||||
logprobs: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
"""
|
||||
Gather logprobs for topk and sampled/prompt token.
|
||||
|
||||
Args:
|
||||
logprobs: (num tokens) x (vocab) tensor
|
||||
num_logprobs: minimum number of logprobs to
|
||||
retain per token
|
||||
token_ids: prompt tokens (if prompt logprobs)
|
||||
or sampled tokens (if sampled
|
||||
logprobs); 1D token ID tensor
|
||||
with (num tokens) elements
|
||||
Must be int64.
|
||||
|
||||
Returns:
|
||||
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
||||
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
||||
Sampled token rank tensor, (num tokens)
|
||||
"""
|
||||
assert token_ids.dtype == torch.int64
|
||||
# Find the topK values.
|
||||
topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
|
||||
|
||||
# Gather the logprob of the prompt or sampled token.
|
||||
token_ids = token_ids.unsqueeze(-1)
|
||||
token_logprobs = logprobs.gather(-1, token_ids)
|
||||
|
||||
# Compute the ranks of the actual token.
|
||||
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
|
||||
|
||||
# Concatenate together with the topk.
|
||||
indices = torch.cat((token_ids, topk_indices), dim=1)
|
||||
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
indices = indices.to(torch.int32)
|
||||
|
||||
return LogprobsTensors(indices, logprobs, token_ranks)
|
||||
|
||||
@staticmethod
|
||||
def _combine_outputs_with_spec_tokens(
|
||||
output_token_ids: list[list[int]],
|
||||
spec_token_ids: list[list[int]] | None = None,
|
||||
) -> list[list[int]]:
|
||||
if spec_token_ids is None:
|
||||
return output_token_ids
|
||||
|
||||
return [
|
||||
[*out, *spec] if spec else out
|
||||
for out, spec in zip(output_token_ids, spec_token_ids)
|
||||
]
|
||||
|
||||
def apply_logits_processors(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
predict_bonus_token: bool,
|
||||
) -> torch.Tensor:
|
||||
bad_words_token_ids = sampling_metadata.bad_words_token_ids
|
||||
any_penalties_or_bad_words = (
|
||||
bool(bad_words_token_ids) or not sampling_metadata.no_penalties
|
||||
)
|
||||
|
||||
output_token_ids = sampling_metadata.output_token_ids
|
||||
if predict_bonus_token and any_penalties_or_bad_words:
|
||||
# Combine base outputs with spec tokens when speculative decoding
|
||||
# is enabled.
|
||||
output_token_ids = self._combine_outputs_with_spec_tokens(
|
||||
output_token_ids,
|
||||
sampling_metadata.spec_token_ids,
|
||||
)
|
||||
|
||||
# Apply allowed token ids.
|
||||
if sampling_metadata.allowed_token_ids_mask is not None:
|
||||
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
|
||||
|
||||
# Apply bad words exclusion.
|
||||
if bad_words_token_ids:
|
||||
apply_bad_words(logits, bad_words_token_ids, output_token_ids)
|
||||
|
||||
# Apply logits processors which can impact greedy sampling.
|
||||
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply penalties (e.g., freq_penalties).
|
||||
logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def apply_penalties(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
output_token_ids: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
if sampling_metadata.no_penalties:
|
||||
return logits
|
||||
|
||||
assert sampling_metadata.prompt_token_ids is not None
|
||||
return apply_all_penalties(
|
||||
logits,
|
||||
sampling_metadata.prompt_token_ids,
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.frequency_penalties,
|
||||
sampling_metadata.repetition_penalties,
|
||||
output_token_ids,
|
||||
)
|
||||
0
vllm/v1/sample/tpu/__init__.py
Normal file
120
vllm/v1/sample/tpu/metadata.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.worker.tpu_input_batch import InputBatch
|
||||
|
||||
DEFAULT_SAMPLING_PARAMS = dict(
|
||||
temperature=-1.0,
|
||||
min_p=0.0,
|
||||
# strictly disabled for now
|
||||
top_k=0,
|
||||
top_p=1.0,
|
||||
# frequency_penalties=0.0,
|
||||
# presence_penalties=0.0,
|
||||
# repetition_penalties=0.0,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TPUSupportedSamplingMetadata:
|
||||
# This class exposes a more xla-friendly interface than SamplingMetadata
|
||||
# on TPU, in particular all arguments should be traceable and no optionals
|
||||
# are allowed, to avoid graph recompilation on Nones.
|
||||
temperature: torch.Tensor = None
|
||||
|
||||
min_p: torch.Tensor = None
|
||||
top_k: torch.Tensor = None
|
||||
top_p: torch.Tensor = None
|
||||
|
||||
all_greedy: bool = True
|
||||
all_random: bool = False
|
||||
|
||||
# Whether logprobs are to be gathered in this batch of requests. To balance
|
||||
# out compile time and runtime, a fixed `max_number_logprobs` value is used
|
||||
# when gathering logprobs, regardless of the values specified in the batch.
|
||||
logprobs: bool = False
|
||||
|
||||
# TODO No penalties for now
|
||||
no_penalties: bool = True
|
||||
prompt_token_ids = None
|
||||
frequency_penalties = None
|
||||
presence_penalties = None
|
||||
repetition_penalties = None
|
||||
# TODO: should use a tensor instead of a Python list
|
||||
output_token_ids: list[list[int]] = field(default_factory=lambda: list())
|
||||
|
||||
min_tokens = None # impl is not vectorized
|
||||
|
||||
logit_bias: list[dict[int, float] | None] = field(default_factory=lambda: list())
|
||||
|
||||
allowed_token_ids_mask = None
|
||||
bad_words_token_ids = None
|
||||
|
||||
# Generator not supported by xla
|
||||
_generators: dict[int, torch.Generator] = field(default_factory=lambda: dict())
|
||||
|
||||
@property
|
||||
def generators(self) -> dict[int, torch.Generator]:
|
||||
# Generator not supported by torch/xla. This field must be immutable.
|
||||
return self._generators
|
||||
|
||||
@classmethod
|
||||
def from_input_batch(
|
||||
cls,
|
||||
input_batch: InputBatch,
|
||||
padded_num_reqs: int,
|
||||
xla_device: torch.device,
|
||||
generate_params_if_all_greedy: bool = False,
|
||||
) -> "TPUSupportedSamplingMetadata":
|
||||
"""
|
||||
Copy sampling tensors slices from `input_batch` to on device tensors.
|
||||
|
||||
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
|
||||
slices dynamic shapes on device tensors. This impl moves the dynamic
|
||||
ops to CPU and produces tensors of fixed `padded_num_reqs` size.
|
||||
|
||||
Args:
|
||||
input_batch: The input batch containing sampling parameters.
|
||||
padded_num_reqs: The padded number of requests.
|
||||
xla_device: The XLA device.
|
||||
generate_params_if_all_greedy: If True, generate sampling parameters
|
||||
even if all requests are greedy. This is useful for cases where
|
||||
we want to pre-compile a graph with sampling parameters, even if
|
||||
they are not strictly needed for greedy decoding.
|
||||
"""
|
||||
needs_logprobs = (
|
||||
input_batch.max_num_logprobs > 0 if input_batch.max_num_logprobs else False
|
||||
)
|
||||
# Early return to avoid an unnecessary CPU-to-TPU copy.
|
||||
if input_batch.all_greedy is True and generate_params_if_all_greedy is False:
|
||||
return cls(all_greedy=True, logprobs=needs_logprobs)
|
||||
|
||||
num_reqs = input_batch.num_reqs
|
||||
|
||||
def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
|
||||
# Pad value is the default one.
|
||||
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
|
||||
|
||||
fill_slice(
|
||||
input_batch.temperature_cpu_tensor, DEFAULT_SAMPLING_PARAMS["temperature"]
|
||||
)
|
||||
fill_slice(input_batch.min_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["min_p"])
|
||||
fill_slice(input_batch.top_k_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_k"])
|
||||
fill_slice(input_batch.top_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_p"])
|
||||
|
||||
# Slice persistent device tensors to a fixed pre-compiled padded shape.
|
||||
return cls(
|
||||
temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device
|
||||
),
|
||||
all_greedy=input_batch.all_greedy,
|
||||
all_random=input_batch.all_random,
|
||||
# TODO enable more and avoid returning None values
|
||||
top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device),
|
||||
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device),
|
||||
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(xla_device),
|
||||
logprobs=needs_logprobs,
|
||||
)
|
||||
215
vllm/v1/sample/tpu/sampler.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Sampler layer implementing TPU supported operations."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
def __init__(self):
|
||||
# TODO(houseroad): Add support for logprobs_mode.
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
# Use float32 for the logits.
|
||||
logits = logits.to(torch.float32)
|
||||
# Sample the next token.
|
||||
sampled = self.sample(logits, sampling_metadata)
|
||||
|
||||
# These are TPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.unsqueeze(-1),
|
||||
logprobs_tensors=None,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def apply_temperature(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
temp: torch.Tensor,
|
||||
all_random: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# Avoid division by zero for greedy sampling (temperature ~ 0.0).
|
||||
if not all_random:
|
||||
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
||||
return logits.div_(temp.unsqueeze(dim=1))
|
||||
|
||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.argmax(dim=-1).view(-1)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
greedy_sampled = self.greedy_sample(logits)
|
||||
|
||||
assert sampling_metadata.temperature is not None
|
||||
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(
|
||||
logits, sampling_metadata.temperature, sampling_metadata.all_random
|
||||
)
|
||||
|
||||
# Apply min_p.
|
||||
if sampling_metadata.min_p is not None:
|
||||
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
logits = apply_top_k_top_p(
|
||||
logits,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
# Random sample.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
random_sampled = self.random_sample(probs, sampling_metadata.generators)
|
||||
|
||||
sampled = torch.where(
|
||||
sampling_metadata.temperature < _SAMPLING_EPS,
|
||||
greedy_sampled,
|
||||
random_sampled,
|
||||
)
|
||||
return sampled
|
||||
|
||||
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
def gather_logprobs(
|
||||
self,
|
||||
logprobs: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
"""
|
||||
Gather logprobs for topk and sampled/prompt token.
|
||||
|
||||
Args:
|
||||
logprobs: (num tokens) x (vocab) tensor
|
||||
num_logprobs: minimum number of logprobs to
|
||||
retain per token
|
||||
token_ids: prompt tokens (if prompt logprobs)
|
||||
or sampled tokens (if sampled
|
||||
logprobs); 1D token ID tensor
|
||||
with (num tokens) elements
|
||||
|
||||
Returns:
|
||||
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
||||
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
||||
Sampled token rank tensor, (num tokens)
|
||||
"""
|
||||
# Find the topK values.
|
||||
topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
|
||||
|
||||
# Gather the logprob of the prompt or sampled token.
|
||||
token_ids = token_ids.unsqueeze(-1)
|
||||
token_logprobs = logprobs.gather(-1, token_ids)
|
||||
|
||||
# Compute the ranks of the actual token.
|
||||
token_ranks = (logprobs >= token_logprobs).sum(-1)
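# Editorial note: this counts how many vocab entries score at least as high
# as the chosen token (including itself), so a rank of 1 means the token was
# the most likely one.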
|
||||
|
||||
# Concatenate together with the topk.
|
||||
indices = torch.cat((token_ids, topk_indices), dim=1)
|
||||
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
indices = indices.to(torch.int32)
|
||||
|
||||
return LogprobsTensors(indices, logprobs, token_ranks)
|
||||
|
||||
def apply_min_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
min_p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Filters logits using adaptive probability thresholding.
|
||||
"""
|
||||
# Convert logits to probability distribution
|
||||
probability_values = torch.nn.functional.softmax(logits, dim=-1)
|
||||
# Calculate maximum probabilities per sequence
|
||||
max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
|
||||
# Reshape min_p for broadcasting
|
||||
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
|
||||
# Identify valid tokens using threshold comparison
|
||||
valid_token_mask = probability_values >= adjusted_min_p
|
||||
# Apply mask using boolean indexing (xla friendly)
|
||||
logits.masked_fill_(~valid_token_mask, -float("inf"))
|
||||
return logits
|
||||
|
||||
def random_sample(
|
||||
self,
|
||||
probs: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
q = torch.empty_like(probs)
|
||||
# NOTE(woosuk): To batch-process the requests without their own seeds,
|
||||
# which is the common case, we first assume that every request does
|
||||
# not have its own seed. Then, we overwrite the values for the requests
|
||||
# that have their own seeds.
|
||||
q.exponential_()
|
||||
if generators:
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
return probs.div_(q).argmax(dim=-1).view(-1)
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k and top-p optimized for TPU.
|
||||
|
||||
This algorithm avoids using torch.scatter which is extremely slow on TPU.
|
||||
This is achieved by finding a "cut-off" element in the original logits;
|
||||
after thresholding the logits at this cut-off, the remaining elements
|
||||
constitute the top-p set.
|
||||
|
||||
Note: in the case of a tie (i.e. multiple cut-off elements present in the
|
||||
logits), all tied elements are included in the top-p set. In other words,
|
||||
this function does not break ties. Instead, the tied tokens have an equal
|
||||
chance of being chosen during final sampling, so the tie is effectively
|
||||
broken at that point.
|
||||
"""
|
||||
probs = logits.softmax(dim=-1)
|
||||
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
||||
top_k_count = top_k_count.unsqueeze(dim=1)
|
||||
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
||||
|
||||
# Make sure rows with top-k disabled (k == vocab_size) are a no-op.
|
||||
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
||||
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
||||
|
||||
elements_to_discard = probs < top_k_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
cumprob = torch.cumsum(probs_sort, dim=-1)
|
||||
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
||||
top_p_mask[:, -1] = False  # always keep at least one token
|
||||
|
||||
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
||||
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
||||
elements_to_discard = probs < top_p_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
return logits
|
||||