# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from dataclasses import dataclass
from importlib.util import find_spec
from math import inf
from typing import Dict, Iterator, List, Optional, Tuple, Union

import msgspec
import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, Logprob,
                           PromptLogprobs, SampleLogprobs, SequenceOutput)
# NOTE: the upstream SamplerOutput is intentionally not imported here, since
# this module defines its own SamplerOutput below and importing it would only
# be shadowed.
from vllm.model_executor.layers.sampler import (
    _apply_min_tokens_penalty, _apply_top_k_top_p, _apply_min_p, _sample,
    SampleResultArgsType, get_logprobs, _build_sampler_output,
    SampleReturnType, SampleResultsDictType, SampleMetadataType,
    MultinomialSamplesType, _modify_greedy_probs_inplace,
    _top_k_top_p_multinomial_with_flashinfer, _multinomial,
    get_pythonized_sample_results)

from vllm_vacc.vllm.model_executor.models.vars import (
    USE_DS3_SAMPLER as use_ds3_sampler,
    USE_DS3_SAMPLER_OP as use_ds3_sampler_op)

if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
    # yapf: enable
else:
    flashinfer_top_k_top_p_sampling = None


class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For each sequence group, we generate a list of SequenceOutput objects,
    each of which contains one possible candidate for the next token.

    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # Holds either (1) the pythonized sampler result (single-step scheduling)
    # or (2) the arguments for later deferred pythonization of the sampler
    # result (multi-step scheduling).
    deferred_sample_results_args: Optional[SampleResultArgsType] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from the last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for the next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # On-device tensor containing the sampled token embeddings (embeddings
    # corresponding to the sampled token ids). Used when prompt embeddings
    # are specified in lieu of prompt token ids or text.
    sampled_token_embeds: Optional[torch.Tensor] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers.
    model_forward_time: Optional[float] = None

    # Time taken in the model execute function. This will include model
    # forward, block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

    def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
        return iter(self.outputs)

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        return (isinstance(other, self.__class__)
                and self.outputs == other.outputs)

    def __repr__(self) -> str:
        """Show the shape of a tensor instead of its values to reduce noise.
        """
        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
                                    else self.sampled_token_probs.shape)
        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None
                                  else self.sampled_token_ids.shape)
        return (f"SamplerOutput(outputs={self.outputs}, "
                f"sampled_token_probs={sampled_token_probs_repr}, "
                f"sampled_token_ids={sampled_token_ids_repr})")


def Sampler_forward(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """
    Single-step scheduling:
    * Perform GPU-side sampling computation & compute
      GPU-side logprobs tensor
    * Pythonize sampling result & logprobs tensor

    Multi-step scheduling:
    * Perform GPU-side sampling computation & compute
      GPU-side logprobs tensor
    * Defer Pythonization of sampling result & logprobs tensor
    * Encapsulate arguments required for deferred Pythonization
      in the :class:`SamplerOutput` structure

    Args:
        logits: (num_tokens, vocab_size).
        sampling_metadata: Metadata for sampling.
    """
    assert logits is not None

    # Prepare sampling tensors with pinned memory to avoid blocking.
    if not sampling_metadata.reuse_sampling_tensors:
        self._init_sampling_tensors(logits, sampling_metadata)
    elif self._do_penalties:
        # In this case, the sampling tensors logic depends on
        # "output_tokens" of a sequence. As a result, we cannot
        # reuse sampling tensors, since "output_tokens" changes
        # between decode runs.
        self._init_sampling_tensors(logits, sampling_metadata)

    assert self._sampling_tensors is not None
    sampling_tensors = self._sampling_tensors
    do_penalties = self._do_penalties
    do_top_p_top_k = self._do_top_p_top_k
    do_min_p = self._do_min_p

    # The batch is homogeneous when every token in it uses the same
    # sampling type.
    is_greedy = (len(sampling_metadata.categorized_sample_indices[
        SamplingType.GREEDY]) == logits.shape[0])
    is_random = (len(sampling_metadata.categorized_sample_indices[
        SamplingType.RANDOM]) == logits.shape[0])
    is_random_seed = (len(sampling_metadata.categorized_sample_indices[
        SamplingType.RANDOM_SEED]) == logits.shape[0])
    max_n_in_batch = sampling_metadata.seq_groups[0].sampling_params.n
    generator = sampling_metadata.seq_groups[0].generator
    min_tokens = sampling_metadata.seq_groups[0].sampling_params.min_tokens

    if use_ds3_sampler and (is_greedy or
                            ((is_random or is_random_seed)
                             and not do_penalties
                             and flashinfer_top_k_top_p_sampling is None
                             and min_tokens <= 0
                             and not do_min_p
                             and max_n_in_batch == 1
                             # and not self._should_modify_greedy_probs_inplace
                             # and not self.include_gpu_probs_tensor
                             )):
        sampling_type = SamplingType.GREEDY
        sample_metadata: SampleMetadataType = {}
        multinomial_samples: MultinomialSamplesType = {}
        greedy_samples: Optional[torch.Tensor] = None
        multinomial_out: Optional[torch.Tensor] = None
        vacc_device = logits.device
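        # Fast path: the whole batch is handed to the vacc DS3 sampler
        # kernel. This branch is only reached when every token in the batch
        # uses the same sampling type and no feature the kernel does not
        # handle (penalties, min-p, min-tokens, n > 1, flashinfer) is
        # requested; anything else falls through to the generic path below.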
        # Create output tensors for sampled token ids and, when the engine
        # asked for GPU-side probability tensors, probs/logprobs buffers.
        if self.include_gpu_probs_tensor:
            sampled_token_ids_tensor = torch.full((logits.shape[0], 1),
                                                  VLLM_INVALID_TOKEN_ID,
                                                  dtype=torch.long,
                                                  device=vacc_device)
            probs_out = torch.empty_like(logits)
            logprobs_out = torch.empty_like(logits)
        else:
            probs_out = None
            logprobs_out = None
            sampled_token_ids_tensor = None

        if is_greedy:
            greedy_samples, _ = torch.vacc.ds3_sampler(
                logits, sampling_tensors.top_ps, sampling_tensors.top_ks,
                sampling_tensors.temperatures, 0)
            sampling_type = SamplingType.GREEDY
            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor = greedy_samples.unsqueeze(-1).to(
                    torch.long)
            if probs_out is not None:
                probs_out = torch.softmax(logits, dim=-1)
                if self._should_modify_greedy_probs_inplace:
                    # Make the greedily-sampled token's probability 1.0 and
                    # everything else 0.0, matching what
                    # _modify_greedy_probs_inplace does on the generic path.
                    sample_indices = sampling_metadata.\
                        categorized_sample_indices[SamplingType.GREEDY].long()
                    probs_out[sample_indices, :] = 0
                    probs_out[sample_indices, greedy_samples] = 1.0
        elif is_random and do_top_p_top_k:
            if use_ds3_sampler_op:
                logits = logits.to(torch.float)
                multinomial_out, probs_out = torch.vacc.ds3_sampler(
                    logits, sampling_tensors.top_ps, sampling_tensors.top_ks,
                    sampling_tensors.temperatures, 2)
                multinomial_out = multinomial_out.view(-1, max_n_in_batch)
            else:
                logits = logits.to(torch.float)
                logits.div_(sampling_tensors.temperatures.to(
                    logits.device).to(logits.dtype).unsqueeze(dim=1))
                logits = torch.vacc.topk_topp(logits, sampling_tensors.top_ps,
                                              sampling_tensors.top_ks)
                probs = torch.softmax(logits, dim=-1, dtype=torch.float)
                probs_out = probs
                # Exponential-race equivalent of torch.multinomial(probs, 1):
                # argmax of probs / q with q ~ Exp(1) draws from the same
                # categorical distribution. Use an out-of-place division so
                # probs_out is not clobbered.
                q = torch.empty_like(probs)
                q.exponential_()
                multinomial_out = (probs / q).argmax(dim=1).view(
                    -1, max_n_in_batch)
            sampling_type = SamplingType.RANDOM
        elif is_random_seed and generator is not None and do_top_p_top_k:
            if use_ds3_sampler_op:
                logits = logits.to(torch.float)
                multinomial_out, probs_out = torch.vacc.ds3_sampler(
                    logits, sampling_tensors.top_ps, sampling_tensors.top_ks,
                    sampling_tensors.temperatures, 1, generator)
                multinomial_out = multinomial_out.view(-1, max_n_in_batch)
            else:
                logits = logits.to(torch.float)
                logits.div_(sampling_tensors.temperatures.to(
                    logits.device).to(logits.dtype).unsqueeze(dim=1))
                logits = torch.vacc.topk_topp(
                    logits, sampling_tensors.top_ps,
                    sampling_tensors.top_ks).to(torch.float)
                probs = torch.softmax(logits, dim=-1, dtype=torch.float)
                probs_out = probs
                # Seeded variant of the exponential-race trick above.
                q = torch.empty_like(probs)
                q.exponential_(generator=generator)
                multinomial_out = (probs / q).argmax(dim=1).view(
                    -1, max_n_in_batch)
            sampling_type = SamplingType.RANDOM_SEED

        multinomial_samples[sampling_type] = multinomial_out
        if sampled_token_ids_tensor is not None:
            if sampling_type != SamplingType.GREEDY:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor = multinomial_samples[
                    sampling_type].to(torch.long)
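        # Rebuild the per-sampling-type metadata that
        # get_pythonized_sample_results() expects, so the device-side samples
        # gathered above can be mapped back to their sequence groups.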
        categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
            t: []
            for t in SamplingType
        }
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            sampling_params = seq_group.sampling_params
            sampling_type = sampling_params.sampling_type
            categorized_seq_group_ids[sampling_type].append(i)

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)

        sample_results_dict: SampleResultsDictType = {}
        maybe_deferred_args = SampleResultArgsType(
            sampling_metadata=sampling_metadata,
            sample_metadata=sample_metadata,
            multinomial_samples=multinomial_samples,
            greedy_samples=greedy_samples,
            # beam_search_logprobs=None,
            sample_results_dict=sample_results_dict)

        if not sampling_metadata.skip_sampler_cpu_output:
            # GPU<->CPU sync happens here. This also converts the sampler
            # output to a Python object.
            # Return the Pythonized sampler result & sampled token ids.
            maybe_deferred_sample_results = get_pythonized_sample_results(
                maybe_deferred_args)
            maybe_sampled_tokens_tensor = sampled_token_ids_tensor
        else:
            # Defer sampler result Pythonization; return the deferred
            # Pythonization args & sampled token ids.
            maybe_deferred_sample_results, maybe_sampled_tokens_tensor = (
                maybe_deferred_args,
                sampled_token_ids_tensor,
            )

        if self.include_gpu_probs_tensor:
            # NOTE: logprobs_out is allocated but never populated on this
            # path.
            on_device_tensors = (probs_out, logprobs_out,
                                 maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            # NOTE: the fast path passes the raw logits where log
            # probabilities are expected.
            logprobs = logits
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    # Generic path: mirrors the upstream vLLM sampler.
    logits = _apply_min_tokens_penalty(logits, sampling_metadata)

    # Apply presence and frequency penalties.
    # if do_penalties:
    #     logits = apply_penalties(
    #         logits, sampling_tensors.prompt_tokens,
    #         sampling_tensors.output_tokens,
    #         sampling_tensors.presence_penalties.to(logits.device),
    #         sampling_tensors.frequency_penalties.to(logits.device),
    #         sampling_tensors.repetition_penalties.to(logits.device))

    # Use float32 to apply temperature scaling.
    # Use in-place division to avoid creating a new tensor.
    logits = logits.to(torch.float)
    logits.div_(sampling_tensors.temperatures.to(logits.device).to(
        logits.dtype).unsqueeze(dim=1))

    if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
        logits = _apply_top_k_top_p(logits,
                                    sampling_tensors.top_ps.to(logits.device),
                                    sampling_tensors.top_ks.to(logits.device))

    if do_min_p:
        logits = _apply_min_p(logits, sampling_tensors.min_ps)

    # We use float32 for probabilities and log probabilities.
    # Compute the probabilities.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    # Compute the log probabilities.
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    # Sample the next tokens.
    maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
        probs,
        logprobs,
        sampling_metadata,
        sampling_tensors,
        include_gpu_probs_tensor=self.include_gpu_probs_tensor,
        modify_greedy_probs=self._should_modify_greedy_probs_inplace,
    )

    if self.include_gpu_probs_tensor:
        # Since we will defer sampler result Pythonization, preserve
        # GPU-side tensors in support of later deferred pythonization
        # of logprobs.
        assert maybe_sampled_tokens_tensor is not None
        on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
    else:
        # Since Pythonization has already happened, don't preserve
        # GPU-side tensors.
        on_device_tensors = None

    # Get the logprobs query results.
    prompt_logprobs = None
    sample_logprobs = None
    if not sampling_metadata.skip_sampler_cpu_output:
        # Pythonize logprobs now (GPU -> CPU); do not defer.
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        prompt_logprobs, sample_logprobs = get_logprobs(
            logprobs, sampling_metadata, maybe_deferred_sample_results)

    return _build_sampler_output(
        maybe_deferred_sample_results,
        sampling_metadata,
        prompt_logprobs,
        sample_logprobs,
        on_device_tensors=on_device_tensors,
        skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
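
# ---------------------------------------------------------------------------
# Illustrative sketch, not used by this module: the DS3 fast path above
# replaces `torch.multinomial(probs, 1)` with an exponential-race argmax.
# Dividing each probability by an independent Exp(1) draw and taking the
# argmax samples from the same categorical distribution (the Gumbel-max trick
# in exponential form). `_demo_exponential_race_sampling` is a hypothetical
# helper that checks this equivalence empirically with plain torch.
def _demo_exponential_race_sampling(num_trials: int = 100_000) -> None:
    probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]]).repeat(num_trials, 1)
    q = torch.empty_like(probs).exponential_()  # q ~ Exp(1), i.i.d.
    samples = (probs / q).argmax(dim=-1)
    # Empirical frequencies should approach [0.1, 0.2, 0.3, 0.4].
    freqs = torch.bincount(samples, minlength=4).float() / num_trials
    print(freqs)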

def rejection_forward(
    self,
    target_with_bonus_probs: torch.Tensor,
    bonus_token_ids: torch.Tensor,
    draft_probs: torch.Tensor,
    draft_token_ids: torch.Tensor,
    seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
    if seeded_seqs is None:
        out, index = torch.vacc.rejection_sampler(target_with_bonus_probs,
                                                  bonus_token_ids,
                                                  draft_probs,
                                                  draft_token_ids, 1)
    else:
        out, index = torch.vacc.rejection_sampler(target_with_bonus_probs,
                                                  bonus_token_ids,
                                                  draft_probs,
                                                  draft_token_ids, 0,
                                                  seeded_seqs[0])
    return out
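
# ---------------------------------------------------------------------------
# Illustrative sketch, under an assumption: `torch.vacc.rejection_sampler` is
# assumed to implement the standard speculative-decoding rejection rule,
# shown below in plain torch for a single draft position. A draft token x is
# accepted with probability min(1, p_target(x) / p_draft(x)); on rejection, a
# replacement is drawn from the normalized residual max(0, p_target -
# p_draft). `_demo_rejection_rule` is a hypothetical helper, not part of the
# vacc API.
def _demo_rejection_rule(target_probs: torch.Tensor,
                         draft_probs: torch.Tensor,
                         draft_token: int) -> int:
    # draft_token was sampled from draft_probs, so p_d > 0.
    p_t = target_probs[draft_token]
    p_d = draft_probs[draft_token]
    if torch.rand(()) < torch.clamp(p_t / p_d, max=1.0):
        # Accept: keep the draft token.
        return draft_token
    # Reject: resample from the residual distribution (its mass is positive
    # whenever rejection is possible).
    residual = torch.clamp(target_probs - draft_probs, min=0.0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, 1))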

class Sampler(nn.Module):

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Single-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Pythonize sampling result & logprobs tensor

        Multi-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Defer Pythonization of sampling result & logprobs tensor
        * Encapsulate arguments required for deferred Pythonization
          in the :class:`SamplerOutput` structure

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1).to(
            logits.device))

        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization, preserve
            # GPU-side tensors in support of later deferred pythonization
            # of logprobs.
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)


def _apply_top_k_top_p_vacc(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    # Sort ascending so the top-k values sit at the end of each row.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1).to(probs_sum.device)
    # At least one token is always kept.
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = torch.empty_like(logits_sort).scatter_(dim=-1,
                                                    index=logits_idx,
                                                    src=logits_sort)
    return logits
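
# ---------------------------------------------------------------------------
# Illustrative usage sketch, not used by this module: a small sanity check of
# `_apply_top_k_top_p_vacc`. Entries outside the per-row top-k become -inf,
# top-p may prune further low-probability survivors, and at least one token
# per row is always kept. `_demo_apply_top_k_top_p_vacc` is a hypothetical
# helper.
def _demo_apply_top_k_top_p_vacc() -> None:
    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                           [4.0, 3.0, 2.0, 1.0]])
    p = torch.tensor([0.9, 0.9])  # keep the smallest set with >= 0.9 mass
    k = torch.tensor([2, 2])      # keep at most the top-2 logits per row
    out = _apply_top_k_top_p_vacc(logits, p, k)
    print(out)  # e.g. row 0 -> [-inf, -inf, 3., 4.]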