import os
from typing import Any, Dict, List, Tuple

import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeStochasticBaseSampler)
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
                           HiddenStates, SequenceGroupMetadata,
                           get_all_seq_ids_and_request_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScores)
from vllm.spec_decode.spec_decode_worker import (logger,
                                                 prepare_prefill_hidden_states)
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
                                   get_all_num_logprobs,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.worker.worker_base import LoraNotSupportedWorkerBase

LOG_LEVEL = os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper()


# Reminder: Please update docs/source/features/compatibility_matrix.md
# if the feature combo becomes valid.
def _verify_tokens(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_scores: SpeculativeScores,
    proposals: SpeculativeProposals,
    max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Determine which speculative tokens are accepted using the
    probabilities of each token according to the proposer and scorer models.

    Returns a tuple of Tensors, one for the accepted token ids and one for
    the logprobs according to the scoring model.
    """
    proposal_lens_list = proposals.proposal_lens.tolist()

    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len. This adds some complexity (splitting the batch into spec
    # and non-spec sequences) and should be removed in the future. It can be
    # done by supporting per-sequence proposal lens.
    (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
        seq_group_metadata_list, proposal_lens_list)
    original_indices = spec_indices + non_spec_indices

    # Get probabilities of target model, including bonus tokens.
    proposal_verifier_probs = proposal_scores.probs
    if len(non_spec_indices) == 0:
        non_spec_token_ids = None
    else:
        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

    # Get probabilities according to proposal method.
    proposal_probs = proposals.proposal_probs
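
    # Shape note (an assumption based on the upstream vLLM spec-decode
    # sampler contract; the torch.vacc kernel below is presumed to consume
    # the same layout): with batch size B, proposal length K and vocab V,
    #   proposal_verifier_probs: [B, K + 1, V]  (target probs incl. bonus)
    #   proposal_probs:          [B, K, V]      (draft probs)
    #   proposal_token_ids:      [B, K]         (draft token ids)
    # and the verified accepted_token_ids tensor is [B, K + 1], padded with
    # -1 after the first rejected position.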

    # Get proposed tokens.
    proposal_token_ids = proposals.proposal_token_ids

    # Sampler arguments
    sampler_extra_kwargs: Dict[str, Any] = {}
    if self.generators and isinstance(self.spec_decode_sampler,
                                      SpecDecodeStochasticBaseSampler):
        sampler_extra_kwargs["seeded_seqs"] = {
            idx: self.generators[sgm.request_id]
            for idx, sgm in enumerate(seq_group_metadata_list)
            if sgm.sampling_params.seed is not None
        }

    if isinstance(self.spec_decode_sampler, RejectionSampler):
        bonus_token_ids = (proposal_scores.token_ids
                           if len(non_spec_indices) == 0 else
                           proposal_scores.token_ids[spec_indices, -1:])
        if (len(sampler_extra_kwargs) > 0
                and len(sampler_extra_kwargs["seeded_seqs"]) > 0):
            seeded_seqs = sampler_extra_kwargs["seeded_seqs"]
        else:
            seeded_seqs = None

        if seeded_seqs is None:
            accepted_token_ids, index = torch.vacc.rejection_sampler(
                proposal_verifier_probs, bonus_token_ids, proposal_probs,
                proposal_token_ids, 1)
        else:
            accepted_token_ids, index = torch.vacc.rejection_sampler(
                proposal_verifier_probs, bonus_token_ids, proposal_probs,
                proposal_token_ids, 0, seeded_seqs[0])

        if LOG_LEVEL == "DEBUG":
            self.spec_decode_sampler.num_accepted_tokens += index.cpu().sum()
            self.spec_decode_sampler.num_draft_tokens += (
                accepted_token_ids.shape[0] *
                (accepted_token_ids.shape[1] - 1))
    else:
        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        index = None

    # Append output tokens from non-speculative sequences to
    # the accepted token ids tensor.
    if len(non_spec_indices) != 0:
        non_spec_token_ids = non_spec_token_ids.expand(
            -1, max_proposal_len + 1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])

    # The upstream rearrangement into the original seq-group order is
    # disabled here:
    # accepted_token_ids[original_indices] = accepted_token_ids.clone()

    logprobs = proposal_scores.logprobs

    # B x K+1 x D
    hidden_states = proposal_scores.hidden_states
    if hidden_states is not None:
        # Only get terminal hidden states for next step
        terminal_metadata = [
            sg for sg in seq_group_metadata_list if sg.do_sample
        ]

        # Contract hidden states based on accepted tokens
        hs_size = hidden_states.shape[-1]
        if index is None:
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
            # Drop non-terminal prefill chunks hidden states.
            hidden_states = hidden_states[
                accepted_index != VLLM_INVALID_TOKEN_ID]
            accepted_index = accepted_index[
                accepted_index != VLLM_INVALID_TOKEN_ID]
        else:
            accepted_index = index
        assert len(accepted_index) == hidden_states.shape[0] == len(
            terminal_metadata)
        # Replaces the upstream gather:
        #   index = accepted_index[:, None, None].expand(-1, 1, hs_size)
        #   second_last_token_hidden_states = hidden_states[:, -2]  # b x d
        #   hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
        second_last_token_hidden_states, hidden_states = (
            torch.vacc.rejection_sampler_update_hidden_states(
                hidden_states, accepted_index))

        # Store hidden states from target model for subsequent decode step
        self.previous_hidden_states = HiddenStates(
            hidden_states, terminal_metadata,
            second_last_token_hidden_states)

    return accepted_token_ids, logprobs


def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                           scoring_time_ms: float,
                           verification_time_ms: float) -> None:
    """Log the speculative stage times. If stat logging is disabled, do
    nothing.
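
    Args:
        average_time_per_proposal_tok_ms: Average proposer time per
            speculative token, in milliseconds.
        scoring_time_ms: Time spent scoring the proposals with the target
            model, in milliseconds.
        verification_time_ms: Time spent verifying (rejection sampling)
            the proposals, in milliseconds.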
""" if self._disable_log_stats: return logger.debug( "SpecDecodeWorker stage times: " "average_time_per_proposal_tok_ms=%.02f " "scoring_time_ms=%.02f verification_time_ms=%.02f", average_time_per_proposal_tok_ms, scoring_time_ms, verification_time_ms) @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec(self, execute_model_req: ExecuteModelRequest, skip_proposer: bool) -> List[SamplerOutput]: """Run a single generation step without any speculation. The input is sent to the proposer and scorer model so that the KV cache is consistent between the two. When skip_proposer is True, the proposer model is not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. """ sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] # Store hidden states from target model execution, BxD. hidden_states = sampler_output.hidden_states if hidden_states is not None: # Only decodes and prefill terminal chunks need a hidden state. seq_group_meta_with_hidden = [ sg for sg in execute_model_req.seq_group_metadata_list if sg.do_sample ] if any(seq.is_prompt for seq in seq_group_meta_with_hidden): # Drop hidden_states with no prediction (eg non-terminal chunks) hidden_states = hidden_states[ torch.where(sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] if self.previous_hidden_states is None and len( seq_group_meta_with_hidden): self.previous_hidden_states = HiddenStates( hidden_states, seq_group_meta_with_hidden) elif self.previous_hidden_states and len( seq_group_meta_with_hidden): self.previous_hidden_states.update(hidden_states, seq_group_meta_with_hidden) # self.previous_hidden_states.prune(seq_group_meta_with_hidden) if not skip_proposer: # We prepare the prefill hidden states here so that there no # additional complexity in worker for spec_decode vs non_spec_decode # flow and execute_model doesn't need additional modifications. execute_model_req.previous_hidden_states = \ prepare_prefill_hidden_states( sampler_output.prefill_hidden_states) for i in range(self._num_spec_prefill_steps): execute_model_req.spec_step_idx = i self.proposer_worker.execute_model(execute_model_req) sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( execute_model_req=execute_model_req, sampler_output=sampler_output) if self._disable_logprobs else [sampler_output]) # Clear device tensors from sampler output. This reduces communication # overhead when the engine runs in a different process than the workers. sampler_output.sampled_token_probs = None sampler_output.sampled_token_ids = None sampler_output.logprobs = None return sampler_output_to_return def _prepare_prefill_hidden_states( prefill_hidden_states: torch.Tensor) -> HiddenStates: # For prefill step in proposer, we run the model for N-1 tokens # because Nth token will be processed in the first decode step. For # N-1 tokens, the input should be 0:N-1 hidden states which should # be concatanated with 1:N token (since output of scorer has to be # the input for proposer). Therefore, we shift the hidden states to # align n-1th hidden state with nth token. 
#print("prefill hiddens is:", prefill_hidden_states.shape, prefill_hidden_states.dtype,prefill_hidden_states) if prefill_hidden_states is None: return None from torch_vacc.vacc.custom_ops import roll_out from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler roll_out_buffer = None if memory_recycler is not None: roll_out_buffer = memory_recycler.EMBEDDING_OUT_BUFFER rolls = roll_out(prefill_hidden_states, shifts=1, dims=0, output=roll_out_buffer) return HiddenStates(rolls) class SpecDecodeWorker(): def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use. This is done by profiling the scorer model (which is typically the larger of the two). Then the total memory which would be used by the scorer cache is divided evenly between the proposer and scorer model KV, such that the number of blocks is equal in both KV caches. """ num_gpu_blocks, num_cpu_blocks = ( self.scorer_worker.determine_num_available_blocks()) scorer_cache_block_size_bytes = ( self.scorer_worker.get_cache_block_size_bytes()) proposer_cache_block_size_bytes = ( self.proposer_worker.get_cache_block_size_bytes()) from vllm.utils import GiB_bytes available_kv_cache_memory= int(os.getenv("VLLM_VACC_KVCACHE_SPACE", "16")) * GiB_bytes if available_kv_cache_memory ==0: torch.vacc.empty_cache() torch.vacc.reset_peak_memory_stats() total_memory = torch.vacc.mem_get_info()[1] self.scorer_worker.model_runner.profile_run() torch.vacc.synchronize() peak_memory = torch.vacc.max_memory_allocated() torch.vacc.empty_cache() torch_allocated_bytes = torch.vacc.memory_stats( )["allocated_bytes.all.current"] total_allocated_bytes = torch.vacc.mem_get_info( )[1] - torch.vacc.mem_get_info()[0] non_torch_allocations = total_allocated_bytes - torch_allocated_bytes if non_torch_allocations > 0: peak_memory += non_torch_allocations available_kv_cache_memory=total_memory*self.scorer_worker.cache_config.gpu_memory_utilization - peak_memory # Determine whether the current num_gpu_blocks meets the memory requirements # based on the block_size_bytes of the score + proposer model. scorer_proposer_cache_bytes = (scorer_cache_block_size_bytes + proposer_cache_block_size_bytes) * num_gpu_blocks if scorer_proposer_cache_bytes < available_kv_cache_memory: new_num_gpu_blocks = num_gpu_blocks else: from vllm.spec_decode.spec_decode_worker import split_num_cache_blocks_evenly new_num_gpu_blocks = split_num_cache_blocks_evenly( scorer_cache_block_size_bytes, proposer_cache_block_size_bytes, num_gpu_blocks) # print("spec decoer 的信息为: available_kv_cache_memory", available_kv_cache_memory, "\n", \ # "scorer_proposer_cache_bytes:",scorer_proposer_cache_bytes, "\n", \ # new_num_gpu_blocks) return new_num_gpu_blocks, num_cpu_blocks