# SPDX-License-Identifier: Apache-2.0
from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceGroupMetadata)
from vllm.utils import (PyObjectCache, async_tensor_h2d,
                        is_pin_memory_available, make_tensor_with_pad)
from vllm.model_executor.sampling_metadata import (SamplingTensors,
                                                   SamplingMetadataCache,
                                                   _prepare_seq_groups,
                                                   SamplingMetadata,
                                                   _SAMPLING_EPS)


@staticmethod
def SamplingMetadata_prepare(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: List[int],
    device: str,
    pin_memory: bool,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> "SamplingMetadata":
    (
        seq_groups,
        selected_token_indices,
        categorized_sample_indices,
        num_prompts,
    ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                            device, generators, cache)
    selected_token_indices = async_tensor_h2d(
        selected_token_indices,
        dtype=torch.int32,  # use int32 instead of long
        target_device=device,
        pin_memory=pin_memory,
    )
    categorized_sample_indices = {
        t: async_tensor_h2d(
            seq_ids,
            dtype=torch.int,
            target_device=device,
            pin_memory=pin_memory,
        )
        for t, seq_ids in categorized_sample_indices.items()
    }

    sampling_metadata = SamplingMetadata(
        seq_groups=seq_groups,
        selected_token_indices=selected_token_indices,
        categorized_sample_indices=categorized_sample_indices,
        num_prompts=num_prompts,
    )
    return sampling_metadata
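
# A minimal usage sketch (illustrative only; it mirrors how vllm's model
# runner invokes SamplingMetadata.prepare, with a hypothetical device
# string -- substitute whatever device this plugin targets):
#
#   metadata = SamplingMetadata_prepare(
#       seq_group_metadata_list,
#       seq_lens=[16],
#       query_lens=[16],
#       device="npu:0",
#       pin_memory=is_pin_memory_available(),
#   )
#   # metadata.selected_token_indices is an int32 tensor on the target
#   # device, rather than the int64 (long) tensor upstream vllm produces.
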
@classmethod
def SamplingTensors_from_lists(
    cls,
    temperatures: List[float],
    top_ps: List[float],
    top_ks: List[int],
    min_ps: List[float],
    presence_penalties: List[float],
    frequency_penalties: List[float],
    repetition_penalties: List[float],
    prompt_tokens: List[array],
    output_tokens: List[array],
    vocab_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> "SamplingTensors":
    # Note that the performance will be very bad without
    # pinned memory.
    pin_memory = is_pin_memory_available()

    do_penalties = prompt_tokens or output_tokens

    if do_penalties:
        prompt_t = make_tensor_with_pad(
            prompt_tokens,
            vocab_size,
            device="cpu",
            dtype=torch.int64,
            pin_memory=pin_memory,
        )
        output_t = make_tensor_with_pad(
            output_tokens,
            vocab_size,
            device="cpu",
            dtype=torch.int64,
            pin_memory=pin_memory,
        )
    else:
        empty_tensor = torch.empty(0, device=device, dtype=torch.long)
        prompt_t = empty_tensor
        output_t = empty_tensor

    temperatures_t = torch.tensor(
        temperatures,
        device="cpu",
        dtype=torch.float32,
        pin_memory=pin_memory,
    )
    top_ps_t = torch.tensor(
        top_ps,
        device="cpu",
        dtype=torch.float32,
        pin_memory=pin_memory,
    )
    min_ps_t = torch.tensor(
        min_ps,
        device="cpu",
        dtype=dtype,
        pin_memory=pin_memory,
    )
    presence_penalties_t = torch.tensor(
        presence_penalties,
        device="cpu",
        dtype=dtype,
        pin_memory=pin_memory,
    )
    frequency_penalties_t = torch.tensor(
        frequency_penalties,
        device="cpu",
        dtype=dtype,
        pin_memory=pin_memory,
    )
    repetition_penalties_t = torch.tensor(
        repetition_penalties,
        device="cpu",
        dtype=dtype,
        pin_memory=pin_memory,
    )
    top_ks_t = torch.tensor(
        top_ks,
        device="cpu",
        dtype=torch.int,
        pin_memory=pin_memory,
    )
    # Because the memory is pinned, we can do non-blocking
    # transfer to device.
    return cls(
        temperatures=temperatures_t.to(device=device, non_blocking=True),
        top_ps=top_ps_t.to(device=device, non_blocking=True),
        top_ks=top_ks_t.to(device=device, non_blocking=True),
        min_ps=min_ps_t.to(device=device, non_blocking=True),
        presence_penalties=presence_penalties_t.to(device=device,
                                                   non_blocking=True),
        frequency_penalties=frequency_penalties_t.to(device=device,
                                                     non_blocking=True),
        repetition_penalties=repetition_penalties_t.to(device=device,
                                                       non_blocking=True),
        prompt_tokens=prompt_t.to(device=device, non_blocking=True),
        output_tokens=output_t.to(device=device, non_blocking=True),
    )
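
# Why pinned memory matters above (a short note, not part of the patch):
# Tensor.to(..., non_blocking=True) only overlaps the host-to-device copy
# with other work when the source CPU tensor lives in pinned (page-locked)
# memory; with ordinary pageable memory the copy is effectively synchronous.
# A hedged standalone example:
#
#   t = torch.tensor([0.7, 0.9], device="cpu", dtype=torch.float32,
#                    pin_memory=True)
#   t_dev = t.to(device="cuda:0", non_blocking=True)  # async H2D copy
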
@classmethod
def SamplingMetadata_from_sampling_metadata(
    cls,
    sampling_metadata: "SamplingMetadata",
    vocab_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> Tuple["SamplingTensors", bool, bool, bool]:
    prompt_tokens: List[array] = []
    output_tokens: List[array] = []
    top_ks: List[int] = []
    temperatures: List[float] = []
    top_ps: List[float] = []
    min_ps: List[float] = []
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    repetition_penalties: List[float] = []
    do_penalties = False
    do_top_p_top_k = False
    do_min_p = False

    assert sampling_metadata.seq_groups is not None
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        temperature = sampling_params.temperature
        p = sampling_params.presence_penalty
        f = sampling_params.frequency_penalty
        r = sampling_params.repetition_penalty
        top_p = sampling_params.top_p
        min_p = sampling_params.min_p

        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # top_k = vocab_size if top_k == -1 else top_k
        # FIXME: fix top_k to avoid odsp bug; currently top_k = 40
        top_k = 40
        if temperature < _SAMPLING_EPS:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0
        if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                   or top_k != vocab_size):
            do_top_p_top_k = True
        if not do_min_p and min_p > _SAMPLING_EPS:
            do_min_p = True
        if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                 or abs(f) >= _SAMPLING_EPS
                                 or abs(r - 1.0) >= _SAMPLING_EPS):
            do_penalties = True

        is_prompt = seq_group.is_prompt
        if is_prompt and sampling_params.prompt_logprobs is not None:
            # For tokens in the prompt that we only need to get
            # their logprobs
            query_len = seq_group.query_len
            assert query_len is not None
            prefill_len = len(seq_group.prompt_logprob_indices)
            temperatures += [temperature] * prefill_len
            top_ps += [top_p] * prefill_len
            top_ks += [top_k] * prefill_len
            min_ps += [min_p] * prefill_len
            presence_penalties += [0] * prefill_len
            frequency_penalties += [0] * prefill_len
            repetition_penalties += [1] * prefill_len

        if seq_group.do_sample:
            sample_lens = len(seq_group.sample_indices)
            assert sample_lens >= len(seq_ids)
            temperatures += [temperature] * sample_lens
            top_ps += [top_p] * sample_lens
            top_ks += [top_k] * sample_lens
            min_ps += [min_p] * sample_lens
            presence_penalties += [p] * sample_lens
            frequency_penalties += [f] * sample_lens
            repetition_penalties += [r] * sample_lens

    if do_penalties:
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            if (seq_group.is_prompt
                    and sampling_params.prompt_logprobs is not None):
                prefill_len = len(seq_group.prompt_logprob_indices)
                prompt_tokens.extend(
                    array(VLLM_TOKEN_ID_ARRAY_TYPE)
                    for _ in range(prefill_len))
                output_tokens.extend(
                    array(VLLM_TOKEN_ID_ARRAY_TYPE)
                    for _ in range(prefill_len))
            if seq_group.do_sample:
                for seq_id in seq_ids:
                    seq_data = seq_group.seq_data[seq_id]
                    prompt_tokens.append(seq_data.prompt_token_ids_array)
                    output_tokens.append(seq_data.output_token_ids_array)

    sampling_tensors = SamplingTensors.from_lists(
        temperatures,
        top_ps,
        top_ks,
        min_ps,
        presence_penalties,
        frequency_penalties,
        repetition_penalties,
        prompt_tokens,
        output_tokens,
        vocab_size,
        device,
        dtype,
    )
    return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
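
# How these replacements are presumably wired in (a sketch under the
# assumption that this module is applied as a monkey patch at plugin import
# time; the actual attachment point is not shown in this file):
#
#   from vllm.model_executor.sampling_metadata import (SamplingMetadata,
#                                                      SamplingTensors)
#   SamplingMetadata.prepare = SamplingMetadata_prepare
#   SamplingTensors.from_lists = SamplingTensors_from_lists
#   SamplingTensors.from_sampling_metadata = (
#       SamplingMetadata_from_sampling_metadata)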