# SPDX-License-Identifier: Apache-2.0

from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
                           SequenceGroupMetadata)
from vllm.utils import (PyObjectCache, async_tensor_h2d,
                        is_pin_memory_available, make_tensor_with_pad)

from vllm.model_executor.sampling_metadata import (
    SamplingTensors, SamplingMetadataCache, _prepare_seq_groups,
    SamplingMetadata, _SAMPLING_EPS)
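
# NOTE: The module-level functions below track SamplingMetadata.prepare,
# SamplingTensors.from_lists and SamplingTensors.from_sampling_metadata from
# vllm.model_executor.sampling_metadata. They are presumably bound onto those
# classes elsewhere to replace the upstream implementations; that binding is
# an assumption and does not happen in this file.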


@staticmethod
def SamplingMetadata_prepare(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: List[int],
    device: str,
    pin_memory: bool,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> "SamplingMetadata":
    (
        seq_groups,
        selected_token_indices,
        categorized_sample_indices,
        num_prompts,
    ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                            device, generators, cache)
    selected_token_indices = async_tensor_h2d(
        selected_token_indices,
        dtype=torch.int32,  # use int32 instead of long
        target_device=device,
        pin_memory=pin_memory,
    )
    categorized_sample_indices = {
        t:
        async_tensor_h2d(
            seq_ids,
            dtype=torch.int,
            target_device=device,
            pin_memory=pin_memory,
        )
        for t, seq_ids in categorized_sample_indices.items()
    }

    sampling_metadata = SamplingMetadata(
        seq_groups=seq_groups,
        selected_token_indices=selected_token_indices,
        categorized_sample_indices=categorized_sample_indices,
        num_prompts=num_prompts,
    )
    return sampling_metadata


@classmethod
def SamplingTensors_from_lists(
    cls,
    temperatures: List[float],
    top_ps: List[float],
    top_ks: List[int],
    min_ps: List[float],
    presence_penalties: List[float],
    frequency_penalties: List[float],
    repetition_penalties: List[float],
    prompt_tokens: List[array],
    output_tokens: List[array],
    vocab_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> "SamplingTensors":
    # Note that the performance will be very bad without
    # pinned memory.
    pin_memory = is_pin_memory_available()

    do_penalties = prompt_tokens or output_tokens

    if do_penalties:
        prompt_t = make_tensor_with_pad(
            prompt_tokens,
            vocab_size,
            device="cpu",
            dtype=torch.int64,
            pin_memory=pin_memory,
        )
        output_t = make_tensor_with_pad(
            output_tokens,
            vocab_size,
            device="cpu",
            dtype=torch.int64,
            pin_memory=pin_memory,
        )
    else:
        empty_tensor = torch.empty(0, device=device, dtype=torch.long)
        prompt_t = empty_tensor
        output_t = empty_tensor

    temperatures_t = torch.tensor(
        temperatures,
        device="cpu",
        dtype=torch.float32,
        pin_memory=pin_memory,
    )
    top_ps_t = torch.tensor(
        top_ps,
        device="cpu",
        dtype=torch.float32,
        pin_memory=pin_memory,
    )
    min_ps_t = torch.tensor(
        min_ps,
        device="cpu",
        dtype=dtype,
        pin_memory=pin_memory,
    )
    presence_penalties_t = torch.tensor(
        presence_penalties,
        device="cpu",
        dtype=dtype,
        pin_memory=pin_memory,
    )
    frequency_penalties_t = torch.tensor(
        frequency_penalties,
        device="cpu",
        dtype=dtype,
        pin_memory=pin_memory,
    )
    repetition_penalties_t = torch.tensor(
        repetition_penalties,
        device="cpu",
        dtype=dtype,
        pin_memory=pin_memory,
    )
    top_ks_t = torch.tensor(
        top_ks,
        device="cpu",
        dtype=torch.int,
        pin_memory=pin_memory,
    )
    # Because the memory is pinned, we can do non-blocking
    # transfers to the device later.
    return cls(
        temperatures=temperatures_t,
        top_ps=top_ps_t,
        top_ks=top_ks_t,
        min_ps=min_ps_t,
        presence_penalties=presence_penalties_t,
        frequency_penalties=frequency_penalties_t,
        repetition_penalties=repetition_penalties_t,
        prompt_tokens=prompt_t,
        output_tokens=output_t,
    )


@classmethod
def SamplingMetadata_from_sampling_metadata(
    cls,
    sampling_metadata: "SamplingMetadata",
    vocab_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> Tuple["SamplingTensors", bool, bool, bool]:
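    """Flatten a SamplingMetadata into per-token parameter lists and build
    SamplingTensors from them via SamplingTensors.from_lists.

    Returns the tensors together with flags indicating whether penalties,
    top-p/top-k, and min-p processing are needed for this batch.
    """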
    prompt_tokens: List[array] = []
    output_tokens: List[array] = []
    top_ks: List[int] = []
    temperatures: List[float] = []
    top_ps: List[float] = []
    min_ps: List[float] = []
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    repetition_penalties: List[float] = []
    do_penalties = False
    do_top_p_top_k = False
    do_min_p = False
    assert sampling_metadata.seq_groups is not None
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        temperature = sampling_params.temperature
        p = sampling_params.presence_penalty
        f = sampling_params.frequency_penalty
        r = sampling_params.repetition_penalty
        top_p = sampling_params.top_p
        min_p = sampling_params.min_p

        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # top_k = vocab_size if top_k == -1 else top_k
        # FIXME: hard-code top_k for now to work around the odsp bug.
        top_k = 40
        if temperature < _SAMPLING_EPS:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0
        if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                   or top_k != vocab_size):
            do_top_p_top_k = True
        if not do_min_p and min_p > _SAMPLING_EPS:
            do_min_p = True
        if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                 or abs(f) >= _SAMPLING_EPS
                                 or abs(r - 1.0) >= _SAMPLING_EPS):
            do_penalties = True

        is_prompt = seq_group.is_prompt
        if is_prompt and sampling_params.prompt_logprobs is not None:
            # For prompt tokens we only need the logprobs, so the penalties
            # are set to their neutral values here.
            query_len = seq_group.query_len
            assert query_len is not None
            prefill_len = len(seq_group.prompt_logprob_indices)
            temperatures += [temperature] * prefill_len
            top_ps += [top_p] * prefill_len
            top_ks += [top_k] * prefill_len
            min_ps += [min_p] * prefill_len
            presence_penalties += [0] * prefill_len
            frequency_penalties += [0] * prefill_len
            repetition_penalties += [1] * prefill_len

        if seq_group.do_sample:
            sample_lens = len(seq_group.sample_indices)
            assert sample_lens >= len(seq_ids)
            temperatures += [temperature] * sample_lens
            top_ps += [top_p] * sample_lens
            top_ks += [top_k] * sample_lens
            min_ps += [min_p] * sample_lens
            presence_penalties += [p] * sample_lens
            frequency_penalties += [f] * sample_lens
            repetition_penalties += [r] * sample_lens
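    # Token id arrays are only gathered when some request actually applies a
    # presence/frequency/repetition penalty; otherwise they are left empty.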
    if do_penalties:
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            if (seq_group.is_prompt
                    and sampling_params.prompt_logprobs is not None):
                prefill_len = len(seq_group.prompt_logprob_indices)
                prompt_tokens.extend(
                    array(VLLM_TOKEN_ID_ARRAY_TYPE)
                    for _ in range(prefill_len))
                output_tokens.extend(
                    array(VLLM_TOKEN_ID_ARRAY_TYPE)
                    for _ in range(prefill_len))
            if seq_group.do_sample:
                for seq_id in seq_ids:
                    seq_data = seq_group.seq_data[seq_id]
                    prompt_tokens.append(seq_data.prompt_token_ids_array)
                    output_tokens.append(seq_data.output_token_ids_array)

    sampling_tensors = SamplingTensors.from_lists(
        temperatures,
        top_ps,
        top_ks,
        min_ps,
        presence_penalties,
        frequency_penalties,
        repetition_penalties,
        prompt_tokens,
        output_tokens,
        vocab_size,
        device,
        dtype,
    )
    return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)