import os
from typing import Any, Dict, List, Optional, Tuple

import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeStochasticBaseSampler)
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
                           HiddenStates, SequenceGroupMetadata,
                           get_all_seq_ids_and_request_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScores)
from vllm.spec_decode.spec_decode_worker import (
    logger, prepare_prefill_hidden_states)
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
                                   get_all_num_logprobs,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.worker.worker_base import LoRANotSupportedWorkerBase

LOG_LEVEL = os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper()

# Reminder: Please update docs/source/features/compatibility_matrix.md
# if the feature combo becomes valid.

# The module-level functions below take an explicit `self` parameter and
# mirror methods of vllm.spec_decode.spec_decode_worker.SpecDecodeWorker;
# they are presumably patched onto that class for the vacc backend.
def _verify_tokens(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_scores: SpeculativeScores,
    proposals: SpeculativeProposals,
    max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Determine which speculative tokens are accepted using the
    probabilities of each token according to the proposer and scorer models.

    Returns a tuple of Tensors, one for the accepted token ids and one for
    the logprobs according to the scoring model.
    """
    proposal_lens_list = proposals.proposal_lens.tolist()

    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len. This adds some complexity (splitting the batch into spec
    # and non-spec sequences) and should be removed in the future. It can be
    # done by supporting per-sequence proposal lens.
    (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
        seq_group_metadata_list, proposal_lens_list)
    original_indices = spec_indices + non_spec_indices
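    # Illustrative sketch (hypothetical values, not from this file): with a
    # batch proposal len of k=3 and proposal_lens_list == [3, 0, 3], the
    # split yields spec_indices == [0, 2] and non_spec_indices == [1], so
    # original_indices == [0, 2, 1].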

    # Get probabilities of target model, including bonus tokens.
    proposal_verifier_probs = proposal_scores.probs

    if len(non_spec_indices) == 0:
        non_spec_token_ids = None
    else:
        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

    # Get bonus tokens from target model. Hoisted out of the sampler
    # branches below so the non-RejectionSampler path can use it too.
    bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

    # Get probabilities according to proposal method.
    proposal_probs = proposals.proposal_probs

    # Get proposed tokens.
    proposal_token_ids = proposals.proposal_token_ids

    # Sampler arguments
    sampler_extra_kwargs: Dict[str, Any] = {}
    if self.generators and isinstance(self.spec_decode_sampler,
                                      SpecDecodeStochasticBaseSampler):
        sampler_extra_kwargs["seeded_seqs"] = {
            idx: self.generators[sgm.request_id]
            for idx, sgm in enumerate(seq_group_metadata_list)
            if sgm.sampling_params.seed is not None
        }
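        # Illustrative sketch (hypothetical values): if only the sequence at
        # batch index 1 set sampling_params.seed, this maps that index to
        # its generator, e.g. seeded_seqs == {1: <torch.Generator>}, so the
        # stochastic sampler can reproduce its accept/reject draws.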

    if isinstance(self.spec_decode_sampler, RejectionSampler):
        # The fused vacc kernel consumes the full token_ids tensor when the
        # whole batch is speculative.
        if len(non_spec_indices) == 0:
            bonus_token_ids = proposal_scores.token_ids
        seeded_seqs = sampler_extra_kwargs.get("seeded_seqs") or None
        if seeded_seqs is None:
            accepted_token_ids, index = torch.vacc.rejection_sampler(
                proposal_verifier_probs,
                bonus_token_ids,
                proposal_probs,
                proposal_token_ids,
                1,
            )
        else:
            accepted_token_ids, index = torch.vacc.rejection_sampler(
                proposal_verifier_probs,
                bonus_token_ids,
                proposal_probs,
                proposal_token_ids,
                0,
                seeded_seqs[0],
            )
        if LOG_LEVEL == "DEBUG":
            self.spec_decode_sampler.num_accepted_tokens += index.cpu().sum()
            self.spec_decode_sampler.num_draft_tokens += (
                accepted_token_ids.shape[0] *
                (accepted_token_ids.shape[1] - 1))
    else:
        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        index = None

    # Append output tokens from non-speculative sequences to
    # the accepted token ids tensor.
    if len(non_spec_indices) != 0:
        non_spec_token_ids = non_spec_token_ids.expand(
            -1, max_proposal_len + 1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
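        # Illustrative sketch (hypothetical values): a non-spec row [[42]]
        # with max_proposal_len == 3 becomes [[42, -1, -1, -1]] after the
        # expand/mask above, matching the [batch, k + 1] layout of
        # accepted_token_ids so the two tensors can be concatenated.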

    # Upstream vLLM rearranges the results back into the order of the
    # original seq group metadata here; this port leaves them in
    # (spec, non-spec) order:
    # accepted_token_ids[original_indices] = accepted_token_ids.clone()

    logprobs = proposal_scores.logprobs
    # B x K+1 x D
    hidden_states = proposal_scores.hidden_states
    if hidden_states is not None:
        # Only get terminal hidden states for next step
        terminal_metadata = [
            sg for sg in seq_group_metadata_list if sg.do_sample
        ]

        # Contract hidden states based on accepted tokens
        hs_size = hidden_states.shape[-1]
        if index is None:
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
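            # Illustrative sketch (hypothetical values): a row of
            # accepted_token_ids like [5, 9, -1, -1] becomes [6, 10, 0, 0]
            # after the +1, count_nonzero gives 2, and add_(-1) yields 1,
            # the position of the last accepted token in that row.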
            # Drop non-terminal prefill chunks hidden states.
            hidden_states = hidden_states[accepted_index !=
                                          VLLM_INVALID_TOKEN_ID]
            accepted_index = accepted_index[accepted_index !=
                                            VLLM_INVALID_TOKEN_ID]
        else:
            accepted_index = index
        assert len(accepted_index) == hidden_states.shape[0] == len(
            terminal_metadata)
        # Eager-mode equivalent of the fused vacc op below:
        #   index = accepted_index[:, None, None].expand(-1, 1, hs_size)
        #   second_last_token_hidden_states = hidden_states[:, -2]  # b x d
        #   hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
        second_last_token_hidden_states, hidden_states = (
            torch.vacc.rejection_sampler_update_hidden_states(
                hidden_states, accepted_index))
        # Store hidden states from target model for subsequent decode step
        self.previous_hidden_states = HiddenStates(
            hidden_states, terminal_metadata,
            second_last_token_hidden_states)
    return accepted_token_ids, logprobs


def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                           scoring_time_ms: float,
                           verification_time_ms: float) -> None:
    """Log the speculative stage times. If stat logging is disabled, do
    nothing.
    """
    if self._disable_log_stats:
        return
    logger.debug(
        "SpecDecodeWorker stage times: "
        "average_time_per_proposal_tok_ms=%.02f "
        "scoring_time_ms=%.02f verification_time_ms=%.02f",
        average_time_per_proposal_tok_ms, scoring_time_ms,
        verification_time_ms)


@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                 skip_proposer: bool) -> List[SamplerOutput]:
    """Run a single generation step without any speculation. The input is
    sent to both the proposer and the scorer model so that the KV cache is
    consistent between the two. When skip_proposer is True, the proposer
    model is not called: its KV cache is not updated for these requests, so
    they cannot use spec decode in subsequent decoding steps.
    """

    sampler_output = self.scorer_worker.execute_model(execute_model_req)
    assert len(sampler_output) == 1
    sampler_output = sampler_output[0]

    # Store hidden states from target model execution, BxD.
    hidden_states = sampler_output.hidden_states
    if hidden_states is not None:
        # Only decodes and prefill terminal chunks need a hidden state.
        seq_group_meta_with_hidden = [
            sg for sg in execute_model_req.seq_group_metadata_list
            if sg.do_sample
        ]
        if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
            # Drop hidden_states with no prediction (eg non-terminal chunks)
            hidden_states = hidden_states[
                torch.where(sampler_output.sampled_token_ids -
                            VLLM_INVALID_TOKEN_ID)[0]]
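            # Illustrative sketch (hypothetical values): with
            # sampled_token_ids == [[7], [VLLM_INVALID_TOKEN_ID], [3]], the
            # difference is zero only for the invalid row, so
            # torch.where(...)[0] keeps rows 0 and 2, i.e. the hidden states
            # of positions that actually sampled a token.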
        if self.previous_hidden_states is None and len(
                seq_group_meta_with_hidden):
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_meta_with_hidden)
        elif self.previous_hidden_states and len(
                seq_group_meta_with_hidden):
            self.previous_hidden_states.update(hidden_states,
                                               seq_group_meta_with_hidden)
            # self.previous_hidden_states.prune(seq_group_meta_with_hidden)

    if not skip_proposer:
        # We prepare the prefill hidden states here so that there is no
        # additional complexity in worker for spec_decode vs non_spec_decode
        # flow and execute_model doesn't need additional modifications.
        execute_model_req.previous_hidden_states = \
            prepare_prefill_hidden_states(
                sampler_output.prefill_hidden_states)
        for i in range(self._num_spec_prefill_steps):
            execute_model_req.spec_step_idx = i
            self.proposer_worker.execute_model(execute_model_req)

    sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
        execute_model_req=execute_model_req, sampler_output=sampler_output)
                                if self._disable_logprobs else
                                [sampler_output])

    # Clear device tensors from sampler output. This reduces communication
    # overhead when the engine runs in a different process than the workers.
    sampler_output.sampled_token_probs = None
    sampler_output.sampled_token_ids = None
    sampler_output.logprobs = None
    return sampler_output_to_return


def _prepare_prefill_hidden_states(
        prefill_hidden_states: torch.Tensor) -> Optional[HiddenStates]:
    # For the prefill step in the proposer, we run the model for N-1 tokens
    # because the Nth token will be processed in the first decode step. For
    # those N-1 tokens, the input should be the 0:N-1 hidden states
    # concatenated with the 1:N tokens (since the output of the scorer has
    # to be the input for the proposer). Therefore, we shift the hidden
    # states to align the (n-1)th hidden state with the nth token.
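    # Illustrative sketch (hypothetical values): for a 4-token prompt with
    # hidden states [h0, h1, h2, h3], rolling by 1 along dim 0 gives
    # [h3, h0, h1, h2], so token t_n lines up with hidden state h_{n-1}.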
    if prefill_hidden_states is None:
        return None
    from torch_vacc.vacc.custom_ops import roll_out
    from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
        memory_recycler)
    roll_out_buffer = None
    if memory_recycler is not None:
        # Reuse a preallocated output buffer instead of allocating per step.
        roll_out_buffer = memory_recycler.EMBEDDING_OUT_BUFFER

    rolls = roll_out(prefill_hidden_states, shifts=1, dims=0,
                     output=roll_out_buffer)
    return HiddenStates(rolls)


class SpecDecodeWorker:

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model
        KV, such that the number of blocks is equal in both KV caches.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())

        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        from vllm.utils import GiB_bytes

        # KV cache budget in GiB, overridable via VLLM_VACC_KVCACHE_SPACE;
        # a value of 0 means "profile the device to measure it".
        available_kv_cache_memory = int(
            os.getenv("VLLM_VACC_KVCACHE_SPACE", "16")) * GiB_bytes

        if available_kv_cache_memory == 0:
            torch.vacc.empty_cache()
            torch.vacc.reset_peak_memory_stats()
            total_memory = torch.vacc.mem_get_info()[1]
            self.scorer_worker.model_runner.profile_run()
            torch.vacc.synchronize()
            peak_memory = torch.vacc.max_memory_allocated()
            torch.vacc.empty_cache()
            torch_allocated_bytes = torch.vacc.memory_stats(
            )["allocated_bytes.all.current"]
            total_allocated_bytes = torch.vacc.mem_get_info(
            )[1] - torch.vacc.mem_get_info()[0]
            non_torch_allocations = (total_allocated_bytes -
                                     torch_allocated_bytes)
            if non_torch_allocations > 0:
                peak_memory += non_torch_allocations
            available_kv_cache_memory = (
                total_memory *
                self.scorer_worker.cache_config.gpu_memory_utilization -
                peak_memory)
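            # Illustrative sketch (hypothetical values): with 32 GiB total,
            # gpu_memory_utilization == 0.9 and a 12 GiB profiled peak
            # (including non-torch allocations), this leaves
            # 32 * 0.9 - 12 = 16.8 GiB for the KV caches.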

        # Check whether the scorer-sized num_gpu_blocks also fits when both
        # the scorer and proposer block sizes are accounted for.
        scorer_proposer_cache_bytes = (
            (scorer_cache_block_size_bytes + proposer_cache_block_size_bytes)
            * num_gpu_blocks)

        if scorer_proposer_cache_bytes < available_kv_cache_memory:
            new_num_gpu_blocks = num_gpu_blocks
        else:
            from vllm.spec_decode.spec_decode_worker import (
                split_num_cache_blocks_evenly)
            new_num_gpu_blocks = split_num_cache_blocks_evenly(
                scorer_cache_block_size_bytes,
                proposer_cache_block_size_bytes, num_gpu_blocks)
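            # Illustrative sketch (hypothetical values): if scorer blocks
            # are 2 MiB, proposer blocks 0.5 MiB, and profiling sized 1000
            # blocks for the scorer alone, the even split keeps roughly
            # 1000 * 2 / (2 + 0.5) = 800 blocks for each cache.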
        return new_num_gpu_blocks, num_cpu_blocks