init src 0.9.2
This commit is contained in:
500
vllm/zero_overhead/sampler.py
Normal file
500
vllm/zero_overhead/sampler.py
Normal file
@@ -0,0 +1,500 @@
|
||||
from importlib.util import find_spec
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.model_executor.layers.sampler import MaybeDeferredSampleResultType, MultinomialSamplesType, SampleMetadataType, \
|
||||
SampleResultArgsType, SampleResultType, SampleResultsDictType, SampleReturnType, Sampler, \
|
||||
SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, \
|
||||
_modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs, _multinomial
|
||||
from vllm.model_executor.layers.utils import apply_penalties
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors, SequenceGroupToSample
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, PromptLogprobs, SampleLogprobs, SequenceOutput
|
||||
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
|
||||
import flashinfer.sampling
|
||||
# yapf: disable
|
||||
from flashinfer.sampling import (
|
||||
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
|
||||
# yapf: enable
|
||||
else:
|
||||
flashinfer_top_k_top_p_sampling = None
|
||||
|
||||
class SampleRecorder:
|
||||
def __init__(self):
|
||||
self.seq_ids:torch.Tensor = None
|
||||
self.sampled_token_ids_tensor:torch.Tensor = None
|
||||
|
||||
last_sampler = None
|
||||
|
||||
def get_last_sampler():
|
||||
return last_sampler
|
||||
|
||||
class ZeroOverheadSampler(Sampler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
"""
|
||||
Single-step scheduling:
|
||||
* Perform GPU-side sampling computation & compute
|
||||
GPU-side logprobs tensor
|
||||
* Pythonize sampling result & logprobs tensor
|
||||
|
||||
Multi-step scheduling:
|
||||
* Perform GPU-side sampling computation & compute
|
||||
GPU-side logprobs tensor
|
||||
* Defer Pythonization of sampling result & logprobs
|
||||
tensor
|
||||
* Encapsulate arguments required for deferred Pythonization
|
||||
in the :class:`SamplerOutput` structure
|
||||
|
||||
Args:
|
||||
logits: (num_tokens, vocab_size).
|
||||
sampling_metadata: Metadata for sampling.
|
||||
"""
|
||||
global last_sampler
|
||||
last_sampler = SampleRecorder()
|
||||
assert logits is not None
|
||||
_, vocab_size = logits.shape
|
||||
|
||||
# Prepare sampling tensors with pinned memory to avoid blocking.
|
||||
if not sampling_metadata.reuse_sampling_tensors:
|
||||
self._init_sampling_tensors(logits, sampling_metadata)
|
||||
elif self._do_penalties:
|
||||
# In this case, the sampling tensors logic depends on
|
||||
# "output_tokens" of a sequence. As a result, we cannot
|
||||
# reuse sampling tensors, since "output_tokens" changes
|
||||
# between decode runs.
|
||||
self._init_sampling_tensors(logits, sampling_metadata)
|
||||
|
||||
assert self._sampling_tensors is not None
|
||||
sampling_tensors = self._sampling_tensors
|
||||
do_penalties = self._do_penalties
|
||||
do_top_p_top_k = self._do_top_p_top_k
|
||||
do_min_p = self._do_min_p
|
||||
|
||||
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
|
||||
|
||||
# Apply presence and frequency penalties.
|
||||
if do_penalties:
|
||||
logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
|
||||
sampling_tensors.output_tokens,
|
||||
sampling_tensors.presence_penalties,
|
||||
sampling_tensors.frequency_penalties,
|
||||
sampling_tensors.repetition_penalties)
|
||||
|
||||
# Use float32 to apply temperature scaling.
|
||||
# Use in-place division to avoid creating a new tensor.
|
||||
logits = logits.to(torch.float)
|
||||
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
|
||||
|
||||
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
|
||||
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
|
||||
sampling_tensors.top_ks)
|
||||
|
||||
if do_min_p:
|
||||
logits = _apply_min_p(logits, sampling_tensors.min_ps)
|
||||
|
||||
# We use float32 for probabilities and log probabilities.
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities.
|
||||
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||
|
||||
# Sample the next tokens.
|
||||
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
|
||||
probs,
|
||||
logprobs,
|
||||
sampling_metadata,
|
||||
sampling_tensors,
|
||||
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
|
||||
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
|
||||
)
|
||||
|
||||
if self.include_gpu_probs_tensor:
|
||||
# Since we will defer sampler result Pythonization,
|
||||
# preserve GPU-side tensors in support of later
|
||||
# deferred pythonization of logprobs
|
||||
assert maybe_sampled_tokens_tensor is not None
|
||||
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
|
||||
else:
|
||||
# Since Pythonization has already happened, don't preserve
|
||||
# GPU-side tensors.
|
||||
on_device_tensors = None
|
||||
|
||||
# Get the logprobs query results.
|
||||
prompt_logprobs = None
|
||||
sample_logprobs = None
|
||||
if not sampling_metadata.skip_sampler_cpu_output:
|
||||
# Pythonize logprobs now (GPU -> CPU); do not defer.
|
||||
assert not isinstance(maybe_deferred_sample_results,
|
||||
SampleResultArgsType)
|
||||
prompt_logprobs, sample_logprobs = get_logprobs(
|
||||
logprobs, sampling_metadata, maybe_deferred_sample_results)
|
||||
|
||||
return _build_sampler_output(
|
||||
maybe_deferred_sample_results,
|
||||
sampling_metadata,
|
||||
prompt_logprobs,
|
||||
sample_logprobs,
|
||||
on_device_tensors=on_device_tensors,
|
||||
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
|
||||
logits=logits)
|
||||
|
||||
def _greedy_sample(
|
||||
selected_seq_groups: List[SequenceGroupToSample],
|
||||
samples: torch.Tensor,
|
||||
) -> SampleResultType:
|
||||
"""Run greedy sampling on a given samples.
|
||||
|
||||
Args:
|
||||
selected_seq_groups: A list of sequence groups batched.
|
||||
samples: (num_selected_samples,) A tensor of samples. The length of
|
||||
samples could be smaller than selected_seq_groups if
|
||||
seq_group.do_sample is False.
|
||||
Returns:
|
||||
Tuple of (next_token_ids, parent_ids). The length of returned list is
|
||||
same as the length of selected_seq_groups. If the corresponding
|
||||
seq_group has do_sample=False, tuple contains ([], [])
|
||||
"""
|
||||
sample_idx = 0
|
||||
results: SampleResultType = []
|
||||
for seq_group in selected_seq_groups:
|
||||
if not seq_group.do_sample:
|
||||
results.append(([], []))
|
||||
continue
|
||||
|
||||
seq_ids = seq_group.seq_ids
|
||||
num_parent_seqs = len(seq_ids)
|
||||
assert num_parent_seqs == 1, (
|
||||
"Greedy sampling should have only one seq.")
|
||||
parent_ids = list(range(num_parent_seqs))
|
||||
assert num_parent_seqs == 1 # not support muti seqences in seqence group
|
||||
next_token_ids = [0] #place holder token id
|
||||
results.append((next_token_ids, parent_ids))
|
||||
sample_idx += num_parent_seqs
|
||||
return results
|
||||
|
||||
def _random_sample(
|
||||
selected_seq_groups: List[SequenceGroupToSample],
|
||||
random_samples: torch.Tensor,
|
||||
) -> SampleResultType:
|
||||
"""Run random sampling on a given samples.
|
||||
|
||||
Args:
|
||||
selected_seq_groups: A list of sequence groups batched.
|
||||
random_samples: (num_selected_samples,) A tensor of samples. The
|
||||
length of samples could be smaller than selected_seq_groups if
|
||||
seq_group.do_sample is False.
|
||||
Returns:
|
||||
Tuple of (next_token_ids, parent_ids). The length of returned list is
|
||||
same as the length of selected_seq_groups. If the corresponding
|
||||
seq_group has do_sample=False, tuple contains ([], [])
|
||||
"""
|
||||
# Find the maximum n value of the prompt phase requests.
|
||||
sample_idx = 0
|
||||
results: SampleResultType = []
|
||||
for seq_group in selected_seq_groups:
|
||||
if not seq_group.do_sample:
|
||||
results.append(([], []))
|
||||
continue
|
||||
|
||||
seq_ids = seq_group.seq_ids
|
||||
sampling_params = seq_group.sampling_params
|
||||
is_prompt = seq_group.is_prompt
|
||||
num_parent_seqs = len(seq_ids)
|
||||
if is_prompt:
|
||||
# Prompt phase.
|
||||
parent_ids = [0] * sampling_params.n
|
||||
assert num_parent_seqs == 1 # not support muti seqences in seqence group
|
||||
next_token_ids = [0] * sampling_params.n #place holder token id
|
||||
else:
|
||||
# Generation phase.
|
||||
parent_ids = list(range(num_parent_seqs))
|
||||
assert num_parent_seqs == 1 # not support muti seqences in seqence group
|
||||
next_token_ids = [0] * num_parent_seqs #place holder token id
|
||||
results.append((next_token_ids, parent_ids))
|
||||
sample_idx += num_parent_seqs
|
||||
return results
|
||||
|
||||
def _sample(
|
||||
probs: torch.Tensor,
|
||||
logprobs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
sampling_tensors: SamplingTensors,
|
||||
include_gpu_probs_tensor: bool,
|
||||
modify_greedy_probs: bool,
|
||||
) -> SampleReturnType:
|
||||
"""
|
||||
Args:
|
||||
probs: (num_query_tokens_in_batch, num_vocab)
|
||||
logprobs: (num_query_tokens_in_batch, num_vocab)
|
||||
sampling_metadata: The metadata for a batch for sampling.
|
||||
sampling_tensors: Tensors that include sampling related metadata.
|
||||
|
||||
Returns:
|
||||
(next_token_ids, parent_seq_ids) for each seq group in a batch.
|
||||
If sampling is skipped, it returns ([], [])
|
||||
sampled_token_ids_tensor: A tensor of sampled token ids.
|
||||
"""
|
||||
return _sample_with_torch(
|
||||
probs,
|
||||
logprobs,
|
||||
sampling_metadata,
|
||||
sampling_tensors,
|
||||
include_gpu_probs_tensor=include_gpu_probs_tensor,
|
||||
modify_greedy_probs=modify_greedy_probs,
|
||||
)
|
||||
|
||||
def _sample_with_torch(
|
||||
probs: torch.Tensor,
|
||||
logprobs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
sampling_tensors: SamplingTensors,
|
||||
include_gpu_probs_tensor: bool,
|
||||
modify_greedy_probs: bool,
|
||||
) -> SampleReturnType:
|
||||
'''Torch-oriented _sample() implementation.
|
||||
|
||||
Single-step scheduling:
|
||||
* Perform GPU-side sampling computation
|
||||
* Immediately Pythonize sampling result
|
||||
|
||||
Multi-step scheduling:
|
||||
* Perform GPU-side sampling computation
|
||||
* Defer Pythonization & preserve GPU-side
|
||||
tensors required for Pythonization
|
||||
'''
|
||||
|
||||
categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
|
||||
t: []
|
||||
for t in SamplingType
|
||||
}
|
||||
categorized_sample_indices = sampling_metadata.categorized_sample_indices
|
||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||
sampling_params = seq_group.sampling_params
|
||||
sampling_type = sampling_params.sampling_type
|
||||
categorized_seq_group_ids[sampling_type].append(i)
|
||||
sample_results_dict: SampleResultsDictType = {}
|
||||
sample_metadata: SampleMetadataType = {}
|
||||
multinomial_samples: MultinomialSamplesType = {}
|
||||
greedy_samples: Optional[torch.Tensor] = None
|
||||
|
||||
# Create output tensor for sampled token ids.
|
||||
if include_gpu_probs_tensor:
|
||||
sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
|
||||
VLLM_INVALID_TOKEN_ID,
|
||||
dtype=torch.long,
|
||||
device=logprobs.device)
|
||||
else:
|
||||
sampled_token_ids_tensor = None
|
||||
|
||||
# Counterintiutively, having two loops here is actually faster.
|
||||
# The first loop can run without waiting on GPU<->CPU sync.
|
||||
for sampling_type in SamplingType:
|
||||
sample_indices = categorized_sample_indices[sampling_type]
|
||||
num_tokens = len(sample_indices)
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
|
||||
seq_group_id = categorized_seq_group_ids[sampling_type]
|
||||
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
|
||||
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
|
||||
long_sample_indices = sample_indices.long()
|
||||
if sampling_type == SamplingType.GREEDY:
|
||||
greedy_samples = torch.argmax(logprobs[long_sample_indices],
|
||||
dim=-1)
|
||||
|
||||
last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
|
||||
|
||||
if sampled_token_ids_tensor is not None:
|
||||
# Store sampled tokens in output tensor.
|
||||
sampled_token_ids_tensor[
|
||||
long_sample_indices] = greedy_samples.unsqueeze(-1)
|
||||
|
||||
if modify_greedy_probs:
|
||||
# If required, modify the probabilities such that sampling from
|
||||
# the modified distribution would always sample the argmax
|
||||
# token id.
|
||||
_modify_greedy_probs_inplace(logprobs, probs,
|
||||
long_sample_indices,
|
||||
greedy_samples)
|
||||
|
||||
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
||||
max_n_in_batch = 1
|
||||
for seq_group in seq_groups:
|
||||
if seq_group.is_prompt:
|
||||
sampling_params = seq_group.sampling_params
|
||||
max_n_in_batch = max(max_n_in_batch, sampling_params.n)
|
||||
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
|
||||
seq_groups)
|
||||
|
||||
if flashinfer_top_k_top_p_sampling is not None:
|
||||
multinomial_samples[
|
||||
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
|
||||
probs[long_sample_indices],
|
||||
sampling_tensors.top_ks[long_sample_indices],
|
||||
sampling_tensors.top_ps[long_sample_indices],
|
||||
max_n_in_batch,
|
||||
seq_groups_arg,
|
||||
)
|
||||
else:
|
||||
multinomial_samples[sampling_type] = _multinomial(
|
||||
probs[long_sample_indices],
|
||||
max_n_in_batch,
|
||||
seq_groups=seq_groups_arg)
|
||||
|
||||
last_sampler.sampled_token_ids_tensor = \
|
||||
multinomial_samples[sampling_type].to(torch.long)
|
||||
|
||||
if sampled_token_ids_tensor is not None:
|
||||
# Store sampled tokens in output tensor.
|
||||
sampled_token_ids_tensor[long_sample_indices] = \
|
||||
multinomial_samples[sampling_type].to(torch.long)
|
||||
|
||||
# Encapsulate arguments for computing Pythonized sampler
|
||||
# results, whether deferred or otherwise.
|
||||
maybe_deferred_args = SampleResultArgsType(
|
||||
sampling_metadata=sampling_metadata,
|
||||
sample_metadata=sample_metadata,
|
||||
multinomial_samples=multinomial_samples,
|
||||
greedy_samples=greedy_samples,
|
||||
sample_results_dict=sample_results_dict)
|
||||
|
||||
if not sampling_metadata.skip_sampler_cpu_output:
|
||||
# GPU<->CPU sync happens here.
|
||||
# This also converts the sampler output to a Python object.
|
||||
# Return Pythonized sampler result & sampled token ids
|
||||
return get_pythonized_sample_results(
|
||||
maybe_deferred_args), sampled_token_ids_tensor
|
||||
else:
|
||||
# Defer sampler result Pythonization; return deferred
|
||||
# Pythonization args & sampled token ids
|
||||
return (
|
||||
maybe_deferred_args,
|
||||
sampled_token_ids_tensor,
|
||||
)
|
||||
|
||||
|
||||
def get_pythonized_sample_results(
|
||||
sample_result_args: SampleResultArgsType) -> SampleResultType:
|
||||
'''This function consumes GPU-side sampler results and computes
|
||||
Pythonized CPU-side sampler results (GPU -> CPU sync.)
|
||||
|
||||
Single-step scheduling: this function is invoked at sampling-time
|
||||
for immediate Pythonization.
|
||||
|
||||
Multi-step scheduling: Pythonization is deferred until after multiple
|
||||
GPU-side steps have been completed.
|
||||
|
||||
Args:
|
||||
sample_result_args: GPU-side inputs to the Pythonization process
|
||||
|
||||
Returns:
|
||||
Pythonized sampler results
|
||||
'''
|
||||
|
||||
(
|
||||
sample_metadata,
|
||||
sampling_metadata,
|
||||
greedy_samples,
|
||||
multinomial_samples,
|
||||
sample_results_dict,
|
||||
) = (
|
||||
sample_result_args.sample_metadata,
|
||||
sample_result_args.sampling_metadata,
|
||||
sample_result_args.greedy_samples,
|
||||
sample_result_args.multinomial_samples,
|
||||
sample_result_args.sample_results_dict,
|
||||
)
|
||||
|
||||
for sampling_type in SamplingType:
|
||||
if sampling_type not in sample_metadata:
|
||||
continue
|
||||
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
|
||||
if sampling_type == SamplingType.GREEDY:
|
||||
sample_results = _greedy_sample(seq_groups, greedy_samples)
|
||||
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
||||
sample_results = _random_sample(seq_groups,
|
||||
multinomial_samples[sampling_type])
|
||||
sample_results_dict.update(zip(seq_group_id, sample_results))
|
||||
|
||||
return [
|
||||
sample_results_dict.get(i, ([], []))
|
||||
for i in range(len(sampling_metadata.seq_groups))
|
||||
]
|
||||
|
||||
def _build_sampler_output(
|
||||
maybe_deferred_sample_results: MaybeDeferredSampleResultType,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
|
||||
sample_logprobs: Optional[List[SampleLogprobs]],
|
||||
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
|
||||
torch.Tensor]],
|
||||
skip_sampler_cpu_output: bool = False,
|
||||
logits: Optional[torch.Tensor] = None
|
||||
) -> SamplerOutput:
|
||||
"""Construct Python objects with the output of sampling.
|
||||
|
||||
Args:
|
||||
on_device_tensors: Tuple containing on-device tensors with the
|
||||
probabilities used in sampling and the sampled token ids. This
|
||||
allows post-processing without copies to CPU/serialization, e.g. in
|
||||
speculative decoding rejection sampling.
|
||||
"""
|
||||
sampler_output: List[CompletionSequenceGroupOutput] = []
|
||||
|
||||
last_sampler.seq_ids = []
|
||||
if skip_sampler_cpu_output:
|
||||
assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
|
||||
deferred_sample_results_args = maybe_deferred_sample_results
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
assert sample_logprobs is not None
|
||||
assert not isinstance(maybe_deferred_sample_results,
|
||||
SampleResultArgsType)
|
||||
assert len(sampling_metadata.seq_groups) \
|
||||
== len(maybe_deferred_sample_results) \
|
||||
== len(prompt_logprobs) \
|
||||
== len(sample_logprobs)
|
||||
deferred_sample_results_args = None
|
||||
|
||||
for (seq_group, sample_result, group_prompt_logprobs,
|
||||
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
|
||||
maybe_deferred_sample_results,
|
||||
prompt_logprobs, sample_logprobs):
|
||||
seq_ids = seq_group.seq_ids
|
||||
next_token_ids, parent_ids = sample_result
|
||||
seq_outputs: List[SequenceOutput] = []
|
||||
for parent_id, next_token_id, logprobs in zip(
|
||||
parent_ids, next_token_ids, group_sample_logprobs):
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id,
|
||||
logprobs))
|
||||
sampler_output.append(
|
||||
CompletionSequenceGroupOutput(seq_outputs,
|
||||
group_prompt_logprobs))
|
||||
if len(seq_outputs) > 0:
|
||||
last_sampler.seq_ids.append(seq_outputs[0].parent_seq_id)
|
||||
|
||||
# If not specified, store None values in SamplerOutput.
|
||||
if on_device_tensors is not None:
|
||||
(sampled_token_probs, logprobs_tensor,
|
||||
sampled_token_ids) = on_device_tensors
|
||||
else:
|
||||
sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
|
||||
None)
|
||||
|
||||
return SamplerOutput(
|
||||
outputs=sampler_output,
|
||||
sampled_token_probs=sampled_token_probs,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=logprobs_tensor,
|
||||
deferred_sample_results_args=deferred_sample_results_args,
|
||||
logits=logits)
|
||||
Reference in New Issue
Block a user