542 lines
24 KiB
Python
542 lines
24 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
"""A layer that samples the next tokens from the model's outputs."""
|
|
import itertools
|
|
import warnings
|
|
from dataclasses import dataclass
|
|
from importlib.util import find_spec
|
|
from math import inf
|
|
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
|
|
|
import msgspec
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
import vllm.envs as envs
|
|
from vllm.model_executor.layers.utils import apply_penalties
|
|
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
|
|
SamplingTensors,
|
|
SequenceGroupToSample)
|
|
from vllm.sampling_params import SamplingType
|
|
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
|
|
CompletionSequenceGroupOutput, Logprob,
|
|
PromptLogprobs, SampleLogprobs, SequenceOutput)
|
|
from vllm.model_executor.layers.sampler import (SamplerOutput,
|
|
_apply_min_tokens_penalty,
|
|
_apply_top_k_top_p,
|
|
_apply_min_p,
|
|
_sample,
|
|
SampleResultArgsType,
|
|
get_logprobs,
|
|
_build_sampler_output,
|
|
SampleReturnType,
|
|
SampleResultsDictType,
|
|
SampleMetadataType,
|
|
MultinomialSamplesType,
|
|
_modify_greedy_probs_inplace,
|
|
_top_k_top_p_multinomial_with_flashinfer,
|
|
_multinomial,
|
|
get_pythonized_sample_results,
|
|
)
|
|
from vllm_vacc.vllm.model_executor.models.vars import USE_DS3_SAMPLER as use_ds3_sampler
|
|
from vllm_vacc.vllm.model_executor.models.vars import USE_DS3_SAMPLER_OP as use_ds3_sampler_op
|
|
|
|
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
|
|
import flashinfer.sampling
|
|
# yapf: disable
|
|
from flashinfer.sampling import (
|
|
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
|
|
|
|
# yapf: enable
|
|
else:
|
|
flashinfer_top_k_top_p_sampling = None
|
|
|
|
class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For each sequence group, we generate a list of SequenceOutput object,
    each of which contains one possible candidate for the next token.

    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors.
    """

    # One CompletionSequenceGroupOutput per sequence group in the batch.
    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # Holds either (1) the pythonized sampler result (single-step scheduling)
    # or (2) what will be arguments for later deferred pythonization of the
    # sampler result (multi-step scheduling)
    deferred_sample_results_args: Optional[SampleResultArgsType] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # On-device tensor containing the sampled token embeddings (embeddings
    # corresponding to the sampled token ids). Used when prompt embeddings are
    # specified in lieu of prompt token ids or text.
    sampled_token_embeds: Optional[torch.Tensor] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers.
    model_forward_time: Optional[float] = None

    # Time taken in the model execute function. This will include model forward,
    # block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

    def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
        return iter(self.outputs)

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs

    def __repr__(self) -> str:
        """Show the shape of a tensor instead of its values to reduce noise.
        """
        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
                                    else self.sampled_token_probs.shape)
        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
                                  self.sampled_token_ids.shape)
        # BUG FIX: the repr string previously ended with a dangling comma and
        # no closing parenthesis ("...sampled_token_ids=...,").
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={sampled_token_probs_repr}, "
            f"sampled_token_ids={sampled_token_ids_repr})")
|
|
|
|
|
|
def Sampler_forward(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Sample next tokens, preferring the vacc ``ds3_sampler`` kernel fast
    path when the whole batch qualifies, otherwise falling back to the
    reference vLLM sampling pipeline.

    Single-step scheduling:
    * Perform GPU-side sampling computation & compute
      GPU-side logprobs tensor
    * Pythonize sampling result & logprobs tensor

    Multi-step scheduling:
    * Perform GPU-side sampling computation & compute
      GPU-side logprobs tensor
    * Defer Pythonization of sampling result & logprobs
      tensor
    * Encapsulate arguments required for deferred Pythonization
      in the :class:`SamplerOutput` structure

    Args:
        logits: (num_tokens, vocab_size).
        sampling_metadata: Metadata for sampling.
    """

    assert logits is not None
    # print(f'Sampler_forward all_greedy={all_greedy}')
    # Prepare sampling tensors with pinned memory to avoid blocking.
    if not sampling_metadata.reuse_sampling_tensors:
        self._init_sampling_tensors(logits, sampling_metadata)
    elif self._do_penalties:
        # In this case, the sampling tensors logic depends on
        # "output_tokens" of a sequence. As a result, we cannot
        # reuse sampling tensors, since "output_tokens" changes
        # between decode runs.
        self._init_sampling_tensors(logits, sampling_metadata)

    assert self._sampling_tensors is not None
    sampling_tensors = self._sampling_tensors
    do_penalties = self._do_penalties
    do_top_p_top_k = self._do_top_p_top_k
    do_min_p = self._do_min_p

    # Each flag is true only when *every* row of `logits` falls into that
    # sampling category; the fast path requires a homogeneous batch.
    is_greedy = (len(sampling_metadata.categorized_sample_indices[SamplingType.GREEDY]) == logits.shape[0])
    is_random = (len(sampling_metadata.categorized_sample_indices[SamplingType.RANDOM]) == logits.shape[0])
    is_random_seed = (len(sampling_metadata.categorized_sample_indices[SamplingType.RANDOM_SEED]) == logits.shape[0])

    # NOTE(review): only the first sequence group's n / generator /
    # min_tokens are inspected — this assumes they are representative of
    # the whole batch; confirm for mixed-parameter batches.
    max_n_in_batch = sampling_metadata.seq_groups[0].sampling_params.n
    generator = sampling_metadata.seq_groups[0].generator
    min_tokens = sampling_metadata.seq_groups[0].sampling_params.min_tokens
    # print("use_ds3_sampler ", use_ds3_sampler)
    # Fast path: all-greedy batches always qualify; random batches qualify
    # only when penalties, flashinfer, min_tokens, min_p and n > 1 are all
    # out of play.
    if use_ds3_sampler == True and (is_greedy == True or ((is_random == True or is_random_seed == True) \
            and do_penalties == False \
            and flashinfer_top_k_top_p_sampling is None \
            and min_tokens <= 0 \
            and do_min_p == False \
            and max_n_in_batch == 1 \
            # and self._should_modify_greedy_probs_inplace == False
            # and self.include_gpu_probs_tensor == False
            )):
        sampling_type = SamplingType.GREEDY
        sample_metadata: SampleMetadataType = {}
        multinomial_samples: MultinomialSamplesType = {}
        greedy_samples: Optional[torch.Tensor] = None
        multinomial_out: Optional[torch.Tensor] = None
        vacc_device = logits.device
        # Create output tensor for sampled token ids.
        if self.include_gpu_probs_tensor:
            sampled_token_ids_tensor = torch.full((logits.shape[0], 1),
                                                  VLLM_INVALID_TOKEN_ID,
                                                  dtype=torch.long,
                                                  device=vacc_device)
            probs_out = torch.empty_like(logits)
            logprobs_out = torch.empty_like(logits)
        else:
            probs_out = None
            logprobs_out = None
            sampled_token_ids_tensor = None
        if is_greedy == True:
            # Mode 0: greedy selection inside the vacc kernel; the second
            # return value is unused here.
            greedy_samples, _ = torch.vacc.ds3_sampler(logits, sampling_tensors.top_ps, sampling_tensors.top_ks, sampling_tensors.temperatures, 0)
            sampling_type = SamplingType.GREEDY
            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor = greedy_samples.unsqueeze(-1).to(torch.long)
            if probs_out is not None:
                # probs_out = torch.softmax(logits.to(torch.float), dim=-1, dtype=torch.float).to(logits)
                probs_out = torch.softmax(logits, dim=-1)
                if self._should_modify_greedy_probs_inplace == True:
                    # Rewrite the greedy rows as a one-hot distribution on
                    # the sampled token.
                    sample_indices = (sampling_metadata.categorized_sample_indices[SamplingType.GREEDY]).long()
                    probs_out[sample_indices, :] = 0
                    probs_out[sample_indices, greedy_samples] = 1.0
        elif is_random == True and do_top_p_top_k == True:
            if use_ds3_sampler_op:
                # Mode 2: fused temperature + top-k/top-p + multinomial.
                logits = logits.to(torch.float)
                multinomial_out, probs_out = torch.vacc.ds3_sampler(logits, sampling_tensors.top_ps, sampling_tensors.top_ks, sampling_tensors.temperatures, 2)
                multinomial_out = multinomial_out.view(-1, max_n_in_batch)
            else:
                logits = logits.to(torch.float)
                logits.div_(sampling_tensors.temperatures.to(logits.device).to(logits.dtype).unsqueeze(dim=1))
                logits = torch.vacc.topk_topp(logits, sampling_tensors.top_ps, sampling_tensors.top_ks)
                probs = torch.softmax(logits, dim=-1, dtype=torch.float)
                probs_out = probs
                # multinomial_out = torch.multinomial(probs, 1)
                # Exponential-race trick: argmax of probs / Exp(1) draws is
                # equivalent to a multinomial sample.
                q = torch.empty_like(probs)
                q.exponential_()
                multinomial_out = probs.div_(q).argmax(dim=1).view(-1, max_n_in_batch)
            sampling_type = SamplingType.RANDOM
        elif is_random_seed == True and generator is not None and do_top_p_top_k == True:
            if use_ds3_sampler_op:
                # print("is_random_seed ", is_random_seed)
                # Mode 1: same as mode 2 but driven by the seeded generator.
                logits = logits.to(torch.float)
                multinomial_out, probs_out = torch.vacc.ds3_sampler(logits, sampling_tensors.top_ps, sampling_tensors.top_ks, sampling_tensors.temperatures, 1, generator)
                multinomial_out = multinomial_out.view(-1, max_n_in_batch)
            else:
                logits = logits.to(torch.float)
                logits.div_(sampling_tensors.temperatures.to(logits.device).to(logits.dtype).unsqueeze(dim=1))
                logits = torch.vacc.topk_topp(logits, sampling_tensors.top_ps, sampling_tensors.top_ks).to(torch.float)
                probs = torch.softmax(logits, dim=-1, dtype=torch.float)
                probs_out = probs
                # torch.manual_seed(sampling_metadata.seq_groups[0].sampling_params.seed)
                # multinomial_out = torch.multinomial(probs, 1)
                q = torch.empty_like(probs)
                q.exponential_(generator=generator)
                multinomial_out = probs.div_(q).argmax(dim=1).view(-1, max_n_in_batch)
            sampling_type = SamplingType.RANDOM_SEED

        # For the greedy case this stores None, which is never read because
        # pythonization uses `greedy_samples` for GREEDY rows.
        multinomial_samples[sampling_type] = multinomial_out

        if sampled_token_ids_tensor is not None:
            if(sampling_type != SamplingType.GREEDY):
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor = multinomial_samples[sampling_type].to(torch.long)

        categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
            t: []
            for t in SamplingType
        }
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            sampling_params = seq_group.sampling_params
            sampling_type = sampling_params.sampling_type
            categorized_seq_group_ids[sampling_type].append(i)

        # NOTE(review): `sampling_type` below is whatever the *last* seq
        # group used; the fast-path guard makes all groups share one type,
        # so this should equal the batch-wide type — confirm.
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        sample_results_dict: SampleResultsDictType = {}

        maybe_deferred_args = SampleResultArgsType(
            sampling_metadata=sampling_metadata,
            sample_metadata=sample_metadata,
            multinomial_samples=multinomial_samples,
            greedy_samples=greedy_samples,
            # beam_search_logprobs=None,
            sample_results_dict=sample_results_dict)

        if not sampling_metadata.skip_sampler_cpu_output:
            # GPU<->CPU sync happens here.
            # This also converts the sampler output to a Python object.
            # Return Pythonized sampler result & sampled token ids
            maybe_deferred_sample_results, maybe_sampled_tokens_tensor = get_pythonized_sample_results(
                maybe_deferred_args), sampled_token_ids_tensor
        else:
            # Defer sampler result Pythonization; return deferred
            # Pythonization args & sampled token ids
            maybe_deferred_sample_results, maybe_sampled_tokens_tensor = (
                maybe_deferred_args,
                sampled_token_ids_tensor,
            )

        if self.include_gpu_probs_tensor:
            on_device_tensors = (probs_out, logprobs_out, maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None
        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            # NOTE(review): on this fast path the (possibly temperature-
            # scaled) logits are passed to get_logprobs in place of true
            # log-softmax values — confirm this is intentional.
            logprobs = logits
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    # ---- Fallback: reference vLLM sampling pipeline ----
    logits = _apply_min_tokens_penalty(logits, sampling_metadata)

    # Apply presence and frequency penalties.
    # if do_penalties:
    #     logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
    #                              sampling_tensors.output_tokens,
    #                              sampling_tensors.presence_penalties.to(logits.device),
    #                              sampling_tensors.frequency_penalties.to(logits.device),
    #                              sampling_tensors.repetition_penalties.to(logits.device))

    # Use float32 to apply temperature scaling.
    # Use in-place division to avoid creating a new tensor.
    logits = logits.to(torch.float)
    logits.div_(sampling_tensors.temperatures.to(logits.device).to(logits.dtype).unsqueeze(dim=1))

    if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
        logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps.to(logits.device),
                                    sampling_tensors.top_ks.to(logits.device))

    if do_min_p:
        logits = _apply_min_p(logits, sampling_tensors.min_ps)

    # We use float32 for probabilities and log probabilities.
    # Compute the probabilities.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    # Compute the log probabilities.
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    # Sample the next tokens.
    maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
        probs,
        logprobs,
        sampling_metadata,
        sampling_tensors,
        include_gpu_probs_tensor=self.include_gpu_probs_tensor,
        modify_greedy_probs=self._should_modify_greedy_probs_inplace,
    )

    if self.include_gpu_probs_tensor:
        # Since we will defer sampler result Pythonization,
        # preserve GPU-side tensors in support of later
        # deferred pythonization of logprobs
        assert maybe_sampled_tokens_tensor is not None
        on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
    else:
        # Since Pythonization has already happened, don't preserve
        # GPU-side tensors.
        on_device_tensors = None

    # Get the logprobs query results.
    prompt_logprobs = None
    sample_logprobs = None
    if not sampling_metadata.skip_sampler_cpu_output:
        # Pythonize logprobs now (GPU -> CPU); do not defer.
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        prompt_logprobs, sample_logprobs = get_logprobs(
            logprobs, sampling_metadata, maybe_deferred_sample_results)

    return _build_sampler_output(
        maybe_deferred_sample_results,
        sampling_metadata,
        prompt_logprobs,
        sample_logprobs,
        on_device_tensors=on_device_tensors,
        skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
|
|
|
|
def rejection_forward(
    self,
    target_with_bonus_probs: torch.Tensor,
    bonus_token_ids: torch.Tensor,
    draft_probs: torch.Tensor,
    draft_token_ids: torch.Tensor,
    seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
    """Run speculative-decoding rejection sampling via the vacc kernel.

    Dispatches to ``torch.vacc.rejection_sampler`` in unseeded mode (flag 1)
    when no generators are supplied, otherwise in seeded mode (flag 0).
    Returns only the accepted token ids; the kernel's second output is
    discarded.
    """
    if seeded_seqs is None:
        accepted, _ = torch.vacc.rejection_sampler(
            target_with_bonus_probs, bonus_token_ids, draft_probs,
            draft_token_ids, 1)
    else:
        # NOTE(review): only the generator stored under key 0 is forwarded
        # to the kernel — confirm per-sequence seeds are not needed here.
        accepted, _ = torch.vacc.rejection_sampler(
            target_with_bonus_probs, bonus_token_ids, draft_probs,
            draft_token_ids, 0, seeded_seqs[0])
    return accepted
|
|
|
|
class Sampler(nn.Module):
    """Reference vLLM sampling pipeline (no vacc ``ds3_sampler`` fast path):
    min-tokens penalty, presence/frequency/repetition penalties, temperature
    scaling, top-k/top-p, min-p, then sampling and logprob extraction.
    """

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Single-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Pythonize sampling result & logprobs tensor

        Multi-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Defer Pythonization of sampling result & logprobs
          tensor
        * Encapsulate arguments required for deferred Pythonization
          in the :class:`SamplerOutput` structure

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        # NOTE(review): vocab_size is unpacked but never used below.
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        # print("tempratures is:", temperatures)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1).to(logits.device))

        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization,
            # preserve GPU-side tensors in support of later
            # deferred pythonization of logprobs
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
|
|
|
|
def _apply_top_k_top_p_vacc(
|
|
logits: torch.Tensor,
|
|
p: torch.Tensor,
|
|
k: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
|
|
|
# Apply top-k.
|
|
top_k_mask = logits_sort.size(1) - k.to(torch.long)
|
|
# Get all the top_k values.
|
|
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
|
|
top_k_mask = logits_sort < top_k_mask
|
|
logits_sort.masked_fill_(top_k_mask, -float("inf"))
|
|
|
|
# Apply top-p.
|
|
probs_sort = logits_sort.softmax(dim=-1)
|
|
probs_sum = probs_sort.cumsum(dim=-1)
|
|
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1).to(probs_sum.device)
|
|
# at least one
|
|
top_p_mask[:, -1] = False
|
|
logits_sort.masked_fill_(top_p_mask, -float("inf"))
|
|
|
|
# Re-sort the probabilities.
|
|
logits = torch.empty_like(logits_sort).scatter_(dim=-1,
|
|
index=logits_idx,
|
|
src=logits_sort)
|
|
return logits |