500 lines
20 KiB
Python
500 lines
20 KiB
Python
from importlib.util import find_spec
|
|
from typing import Dict, List, Optional, Tuple
|
|
import torch
|
|
|
|
from vllm import envs
|
|
from vllm.distributed.parallel_state import get_tp_group
|
|
from vllm.model_executor.layers.sampler import MaybeDeferredSampleResultType, MultinomialSamplesType, SampleMetadataType, \
|
|
SampleResultArgsType, SampleResultType, SampleResultsDictType, SampleReturnType, Sampler, \
|
|
SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, \
|
|
_modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs, _multinomial
|
|
from vllm.model_executor.layers.utils import apply_penalties
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors, SequenceGroupToSample
|
|
from vllm.sampling_params import SamplingType
|
|
from vllm.sequence import VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, PromptLogprobs, SampleLogprobs, SequenceOutput
|
|
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
|
|
import flashinfer.sampling
|
|
# yapf: disable
|
|
from flashinfer.sampling import (
|
|
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
|
|
# yapf: enable
|
|
else:
|
|
flashinfer_top_k_top_p_sampling = None
|
|
|
|
class SampleRecorder:
    """Mutable record of what the most recent sampler call produced.

    A fresh instance is created at the start of every
    ``ZeroOverheadSampler.forward`` call and published through the
    module-level ``last_sampler`` variable. Both attributes start out as
    ``None`` and are filled in while that call runs.
    """

    def __init__(self):
        # Replaced by a list in `_build_sampler_output`, which appends the
        # `parent_seq_id` of the first output of every sequence group that
        # produced one (presumably ints -- confirm against SequenceOutput).
        self.seq_ids: Optional[List[int]] = None
        # Set in `_sample_with_torch` to the sampled token ids converted to
        # a tensor; stays None if no sampling branch ran.
        self.sampled_token_ids_tensor: Optional[torch.Tensor] = None
|
|
|
|
# Recorder published by the most recent ZeroOverheadSampler.forward() call.
# It is None until the first forward pass rebinds it.
last_sampler = None


def get_last_sampler():
    """Return whatever the module-level ``last_sampler`` currently holds.

    Returns:
        ``None`` if no forward pass has rebound the global yet, otherwise
        the object most recently assigned to it.
    """
    return last_sampler
|
|
|
|
class ZeroOverheadSampler(Sampler):
    """``Sampler`` subclass that records each step in ``last_sampler``.

    The forward pass follows the usual order (penalties, temperature,
    top-k/top-p, min-p, softmax, sampling) but delegates to this module's
    own ``_sample`` and ``_build_sampler_output`` helpers, which write the
    sampled token ids and sequence ids into the module-level recorder.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` and build the sampler output.

        When ``sampling_metadata.skip_sampler_cpu_output`` is false, the
        sampling results and logprobs are converted to Python objects
        before returning. When it is true, that conversion is deferred and
        the arguments needed for it are carried in the returned
        :class:`SamplerOutput`.

        Side effect: the module-level ``last_sampler`` is rebound to a new
        ``SampleRecorder`` at the very start of every call.

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.

        Returns:
            The ``SamplerOutput`` built by ``_build_sampler_output``.
        """
        global last_sampler
        last_sampler = SampleRecorder()
        assert logits is not None
        # The two-name unpack fails unless `logits` is two-dimensional.
        _num_tokens, _vocab_size = logits.shape

        # Rebuild the sampling tensors unless they may be reused. With
        # penalties enabled they are always rebuilt: the original note
        # explains that they depend on each sequence's "output_tokens",
        # which changes between decode runs. The `or` short-circuits, so
        # `_do_penalties` is only read when reuse was requested.
        if (not sampling_metadata.reuse_sampling_tensors
                or self._do_penalties):
            self._init_sampling_tensors(logits, sampling_metadata)

        sampling_tensors = self._sampling_tensors
        assert sampling_tensors is not None

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Presence, frequency and repetition penalties.
        if self._do_penalties:
            logits = apply_penalties(
                logits,
                sampling_tensors.prompt_tokens,
                sampling_tensors.output_tokens,
                sampling_tensors.presence_penalties,
                sampling_tensors.frequency_penalties,
                sampling_tensors.repetition_penalties,
            )

        # Temperature scaling is done in float32, dividing in place so no
        # second full-size tensor is created.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        # When the flashinfer kernel was imported, top-k/top-p is skipped
        # here; `_sample_with_torch` hands top_ks/top_ps to the flashinfer
        # helper instead.
        flashinfer_missing = flashinfer_top_k_top_p_sampling is None
        if self._do_top_p_top_k and flashinfer_missing:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if self._do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # Probabilities and log-probabilities are both kept in float32.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Draw the next tokens.
        sample_out = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = sample_out

        # Keep the device-side tensors only when asked to; they support a
        # later, deferred conversion of the logprobs.
        on_device_tensors = None
        if self.include_gpu_probs_tensor:
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)

        # Logprob query results stay None when CPU output is skipped.
        skip_cpu_output = sampling_metadata.skip_sampler_cpu_output
        prompt_logprobs = None
        sample_logprobs = None
        if not skip_cpu_output:
            # Convert logprobs to Python objects now (GPU -> CPU).
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=skip_cpu_output,
            logits=logits)
|
|
|
|
def _greedy_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    samples: torch.Tensor,
) -> SampleResultType:
    """Build placeholder greedy-sampling results without reading ``samples``.

    No value is copied out of ``samples``: every sampled sequence gets the
    placeholder token id ``0``. The real ids presumably reach the consumer
    through ``last_sampler.sampled_token_ids_tensor`` -- confirm with the
    caller.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        samples: (num_selected_samples,) A tensor of samples. Accepted for
            signature compatibility only; it is never accessed.

    Returns:
        One ``(next_token_ids, parent_ids)`` tuple per entry of
        ``selected_seq_groups``, in order. Groups with ``do_sample=False``
        yield ``([], [])``; all others yield ``([0], [0])``.

    Raises:
        AssertionError: If a sampled group has other than one sequence.
    """
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        # Multiple sequences in one sequence group are not supported.
        num_parent_seqs = len(seq_group.seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")

        # Fresh lists per group: ([placeholder token id], [parent index 0]).
        results.append(([0], [0]))
    return results
|
|
|
|
def _random_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    random_samples: torch.Tensor,
) -> SampleResultType:
    """Build placeholder random-sampling results without reading the samples.

    No value is copied out of ``random_samples``: every slot holds the
    placeholder token id ``0``. The real ids presumably reach the consumer
    through ``last_sampler.sampled_token_ids_tensor`` -- confirm with the
    caller.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        random_samples: (num_selected_samples,) A tensor of samples.
            Accepted for signature compatibility only; never accessed.

    Returns:
        One ``(next_token_ids, parent_ids)`` tuple per entry of
        ``selected_seq_groups``, in order. Groups with ``do_sample=False``
        yield ``([], [])``. Prompt-phase groups yield two lists of
        ``sampling_params.n`` zeros; generation-phase groups yield
        ``([0], [0])``.

    Raises:
        AssertionError: If a sampled group has other than one sequence.
    """
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        # Multiple sequences in one sequence group are not supported, in
        # either the prompt or the generation phase.
        num_parent_seqs = len(seq_group.seq_ids)
        assert num_parent_seqs == 1, (
            "Random sampling should have only one seq.")

        # Prompt phase fans out to `n` samples from parent 0; generation
        # phase produces one sample per parent sequence (exactly one here).
        if seq_group.is_prompt:
            num_samples = seq_group.sampling_params.n
        else:
            num_samples = num_parent_seqs

        # Placeholder token ids and parent indices are all zero.
        results.append(([0] * num_samples, [0] * num_samples))
    return results
|
|
|
|
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    """Sample next tokens; thin wrapper around ``_sample_with_torch``.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.
        include_gpu_probs_tensor: Forwarded unchanged.
        modify_greedy_probs: Forwarded unchanged.

    Returns:
        Exactly what ``_sample_with_torch`` returns: a pair whose first
        item is the sampling result (or the deferred arguments needed to
        compute it) and whose second item is the sampled-token-id tensor
        or ``None``.
    """
    # Every argument is forwarded untouched, by name.
    return _sample_with_torch(
        probs=probs,
        logprobs=logprobs,
        sampling_metadata=sampling_metadata,
        sampling_tensors=sampling_tensors,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )
|
|
|
|
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    '''Torch-based implementation of ``_sample()``.

    Runs the device-side sampling for every sampling type that has tokens
    in the batch, stores the sampled ids on the module-level
    ``last_sampler`` recorder, and then either converts the results to
    Python objects immediately or returns the arguments needed to do so
    later, depending on ``sampling_metadata.skip_sampler_cpu_output``.

    Args:
        probs: Probabilities; rows are selected with the per-type sample
            indices.
        logprobs: Log-probabilities; its first dimension and device size
            the optional output tensor.
        sampling_metadata: Batch metadata (sequence groups, categorized
            sample indices, ``skip_sampler_cpu_output``).
        sampling_tensors: Source of ``top_ks`` / ``top_ps`` for the
            flashinfer path.
        include_gpu_probs_tensor: If true, also fill and return a
            ``(logprobs.shape[0], 1)`` long tensor of sampled token ids.
        modify_greedy_probs: If true, call
            ``_modify_greedy_probs_inplace`` for greedy rows.

    Returns:
        ``(results, sampled_token_ids_tensor)``. ``results`` is the output
        of ``get_pythonized_sample_results`` when CPU output is not
        skipped, otherwise the ``SampleResultArgsType`` bundle. The tensor
        is ``None`` unless ``include_gpu_probs_tensor`` is true.
    '''
    all_seq_groups = sampling_metadata.seq_groups

    # Bucket the sequence-group indices by their sampling type.
    group_ids_by_type: Dict[SamplingType, List[int]] = {
        kind: []
        for kind in SamplingType
    }
    for group_idx, seq_group in enumerate(all_seq_groups):
        kind = seq_group.sampling_params.sampling_type
        group_ids_by_type[kind].append(group_idx)
    indices_by_type = sampling_metadata.categorized_sample_indices

    sample_results_dict: SampleResultsDictType = {}
    sample_metadata: SampleMetadataType = {}
    multinomial_samples: MultinomialSamplesType = {}
    greedy_samples: Optional[torch.Tensor] = None

    # Optional output tensor of sampled ids, pre-filled with the invalid id.
    sampled_token_ids_tensor = None
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
                                              VLLM_INVALID_TOKEN_ID,
                                              dtype=torch.long,
                                              device=logprobs.device)

    # Counterintuitively, having two loops here is actually faster.
    # This first loop can run without waiting on GPU<->CPU sync; the second
    # one lives in get_pythonized_sample_results().
    #
    # NOTE(review): each branch below assigns
    # last_sampler.sampled_token_ids_tensor, so when a batch contains more
    # than one sampling type only the last processed type's tokens remain
    # on the recorder -- confirm this is intended.
    random_kinds = (SamplingType.RANDOM, SamplingType.RANDOM_SEED)
    for sampling_type in SamplingType:
        sample_indices = indices_by_type[sampling_type]
        if len(sample_indices) == 0:
            continue

        group_ids = group_ids_by_type[sampling_type]
        seq_groups = [all_seq_groups[idx] for idx in group_ids]
        sample_metadata[sampling_type] = (group_ids, seq_groups)
        long_sample_indices = sample_indices.long()

        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)
            greedy_column = greedy_samples.unsqueeze(-1)

            last_sampler.sampled_token_ids_tensor = greedy_column

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = greedy_column

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling
                # from the modified distribution would always sample the
                # argmax token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in random_kinds:
            # Largest `n` among prompt-phase groups, never below 1.
            prompt_ns = (group.sampling_params.n for group in seq_groups
                         if group.is_prompt)
            max_n_in_batch = max(1, max(prompt_ns, default=1))
            # Only the non-RANDOM (seeded) type passes the groups through.
            seq_groups_arg = (seq_groups
                              if sampling_type != SamplingType.RANDOM else
                              None)

            selected_probs = probs[long_sample_indices]
            if flashinfer_top_k_top_p_sampling is None:
                drawn = _multinomial(selected_probs,
                                     max_n_in_batch,
                                     seq_groups=seq_groups_arg)
            else:
                drawn = _top_k_top_p_multinomial_with_flashinfer(
                    selected_probs,
                    sampling_tensors.top_ks[long_sample_indices],
                    sampling_tensors.top_ps[long_sample_indices],
                    max_n_in_batch,
                    seq_groups_arg,
                )
            multinomial_samples[sampling_type] = drawn
            drawn_long = drawn.to(torch.long)

            last_sampler.sampled_token_ids_tensor = drawn_long

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = drawn_long

    # Bundle everything needed to compute the Python-level sampler results,
    # whether that happens now or later.
    maybe_deferred_args = SampleResultArgsType(
        sampling_metadata=sampling_metadata,
        sample_metadata=sample_metadata,
        multinomial_samples=multinomial_samples,
        greedy_samples=greedy_samples,
        sample_results_dict=sample_results_dict)

    if sampling_metadata.skip_sampler_cpu_output:
        # Defer the conversion; hand back its arguments and the token ids.
        return (
            maybe_deferred_args,
            sampled_token_ids_tensor,
        )

    # Convert the sampler output to Python objects right away and return it
    # together with the sampled token ids.
    return get_pythonized_sample_results(
        maybe_deferred_args), sampled_token_ids_tensor
|
|
|
|
|
|
def get_pythonized_sample_results(
        sample_result_args: SampleResultArgsType) -> SampleResultType:
    '''Turn the bundled sampler arguments into per-group Python results.

    For every sampling type recorded in ``sample_metadata`` the matching
    helper (``_greedy_sample`` or ``_random_sample``) is called, and its
    results are stored in ``sample_result_args.sample_results_dict`` under
    the corresponding sequence-group indices. That dict is mutated in
    place.

    Args:
        sample_result_args: inputs to the Pythonization process

    Returns:
        One ``(next_token_ids, parent_ids)`` entry per sequence group in
        ``sampling_metadata.seq_groups``, in order; groups without a stored
        result get ``([], [])``.
    '''
    per_type_metadata = sample_result_args.sample_metadata
    sampling_metadata = sample_result_args.sampling_metadata
    greedy_samples = sample_result_args.greedy_samples
    multinomial_samples = sample_result_args.multinomial_samples
    results_by_group = sample_result_args.sample_results_dict

    random_kinds = (SamplingType.RANDOM, SamplingType.RANDOM_SEED)
    for sampling_type in SamplingType:
        if sampling_type in per_type_metadata:
            group_ids, seq_groups = per_type_metadata[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(seq_groups, greedy_samples)
            elif sampling_type in random_kinds:
                sample_results = _random_sample(
                    seq_groups, multinomial_samples[sampling_type])
            results_by_group.update(zip(group_ids, sample_results))

    num_groups = len(sampling_metadata.seq_groups)
    return [
        results_by_group.get(group_idx, ([], []))
        for group_idx in range(num_groups)
    ]
|
|
|
|
def _build_sampler_output(
    maybe_deferred_sample_results: MaybeDeferredSampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
    sample_logprobs: Optional[List[SampleLogprobs]],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
    skip_sampler_cpu_output: bool = False,
    logits: Optional[torch.Tensor] = None
) -> SamplerOutput:
    """Assemble the ``SamplerOutput`` for one sampling step.

    Side effect: ``last_sampler.seq_ids`` is reset to an empty list and,
    on the non-deferred path, receives the ``parent_seq_id`` of the first
    output of every sequence group that produced at least one output.

    Args:
        maybe_deferred_sample_results: Per-group
            ``(next_token_ids, parent_ids)`` results, or a
            ``SampleResultArgsType`` when ``skip_sampler_cpu_output`` is
            true.
        sampling_metadata: Batch metadata providing ``seq_groups``.
        prompt_logprobs: Per-group prompt logprobs; must not be ``None``
            unless CPU output is skipped.
        sample_logprobs: Per-group sample logprobs; must not be ``None``
            unless CPU output is skipped.
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization,
            e.g. in speculative decoding rejection sampling.
        skip_sampler_cpu_output: If true, no per-sequence outputs are built
            and the deferred arguments are attached to the output instead.
        logits: Passed through to ``SamplerOutput`` unchanged.

    Returns:
        The populated ``SamplerOutput``.
    """
    sampler_output: List[CompletionSequenceGroupOutput] = []

    last_sampler.seq_ids = []
    if skip_sampler_cpu_output:
        assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
        deferred_sample_results_args = maybe_deferred_sample_results
    else:
        assert prompt_logprobs is not None
        assert sample_logprobs is not None
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        assert len(sampling_metadata.seq_groups) \
            == len(maybe_deferred_sample_results) \
            == len(prompt_logprobs) \
            == len(sample_logprobs)
        deferred_sample_results_args = None

        per_group = zip(sampling_metadata.seq_groups,
                        maybe_deferred_sample_results, prompt_logprobs,
                        sample_logprobs)
        for (seq_group, sample_result, group_prompt_logprobs,
             group_sample_logprobs) in per_group:
            seq_ids = seq_group.seq_ids
            next_token_ids, parent_ids = sample_result
            seq_outputs: List[SequenceOutput] = [
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)
                for parent_id, next_token_id, logprobs in zip(
                    parent_ids, next_token_ids, group_sample_logprobs)
            ]
            sampler_output.append(
                CompletionSequenceGroupOutput(seq_outputs,
                                              group_prompt_logprobs))
            # Record one sequence id per group that actually sampled.
            if seq_outputs:
                last_sampler.seq_ids.append(seq_outputs[0].parent_seq_id)

    # Without on-device tensors, SamplerOutput gets None for all three.
    if on_device_tensors is None:
        sampled_token_probs = None
        logprobs_tensor = None
        sampled_token_ids = None
    else:
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
        deferred_sample_results_args=deferred_sample_results_args,
        logits=logits)