Files
enginex-vastai-va16-vllm/vllm_vacc/vllm/spec_decode/spec_decode_worker.py
2026-04-02 04:55:00 +00:00

313 lines
15 KiB
Python

from typing import Any, Dict, List, Tuple
import torch
from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeStochasticBaseSampler
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
HiddenStates, SequenceGroupMetadata)
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScores)
from vllm.spec_decode.util import split_batch_by_proposal_len
from vllm.worker.worker_base import LoRANotSupportedWorkerBase
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.util import (Timer, create_logprobs_output,
create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.spec_decode.spec_decode_worker import prepare_prefill_hidden_states
from vllm.spec_decode.spec_decode_worker import logger
import os
LOG_LEVEL = os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper()
# Reminder: Please update docs/source/features/compatibility_matrix.md
# if the feature combo becomes valid.
def _verify_tokens(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_scores: SpeculativeScores,
    proposals: SpeculativeProposals,
    max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Determine which speculative tokens are accepted using the
    probabilities of each token according to the proposer and scorer models.

    Args:
        seq_group_metadata_list: Metadata for every sequence group in the
            batch (both speculative and non-speculative).
        proposal_scores: Target-model scores for the proposed tokens
            (probs, token_ids, logprobs and optionally hidden_states).
        proposals: Draft-model proposals (token ids, probs, per-seq lens).
        max_proposal_len: The batch-wide proposal length.

    Returns:
        A tuple of Tensors: the accepted token ids and the logprobs
        according to the scoring model.
    """
    proposal_lens_list = proposals.proposal_lens.tolist()

    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len. This adds some complexity (splitting the batch into spec
    # and non spec sequences) and should be removed in the future. It can be
    # done by supporting per-sequence proposal lens.
    (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
        seq_group_metadata_list, proposal_lens_list)
    # NOTE(review): upstream vLLM rearranges the results back into this
    # original order at the end; that step is intentionally disabled in this
    # port, so original_indices is currently unused.
    original_indices = spec_indices + non_spec_indices

    # Probabilities of the target model, including the bonus token position.
    proposal_verifier_probs = proposal_scores.probs

    if len(non_spec_indices) == 0:
        non_spec_token_ids = None
    else:
        # Non-speculative sampled tokens from the target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

    # Probabilities according to the proposal method.
    proposal_probs = proposals.proposal_probs
    # Proposed (draft) tokens.
    proposal_token_ids = proposals.proposal_token_ids

    # Per-sequence RNGs, only meaningful for stochastic samplers with
    # seeded sampling requests.
    sampler_extra_kwargs: Dict[str, Any] = {}
    if self.generators and isinstance(self.spec_decode_sampler,
                                      SpecDecodeStochasticBaseSampler):
        sampler_extra_kwargs["seeded_seqs"] = {
            idx: self.generators[sgm.request_id]
            for idx, sgm in enumerate(seq_group_metadata_list)
            if sgm.sampling_params.seed is not None
        }

    if isinstance(self.spec_decode_sampler, RejectionSampler):
        # Fused vacc path. The fused kernel takes the full token_ids
        # tensor when the batch is all-spec, otherwise only the bonus
        # (last) token of each speculative sequence.
        bonus_token_ids = (proposal_scores.token_ids
                           if len(non_spec_indices) == 0 else
                           proposal_scores.token_ids[spec_indices, -1:])
        if (len(sampler_extra_kwargs) > 0
                and len(sampler_extra_kwargs["seeded_seqs"]) > 0):
            seeded_seqs = sampler_extra_kwargs["seeded_seqs"]
        else:
            seeded_seqs = None
        if seeded_seqs is None:
            accepted_token_ids, index = torch.vacc.rejection_sampler(
                proposal_verifier_probs,
                bonus_token_ids,
                proposal_probs,
                proposal_token_ids,
                1
            )
        else:
            # NOTE(review): only the generator at batch index 0 is passed
            # to the fused kernel — presumably it supports a single shared
            # generator; confirm against the kernel's contract.
            accepted_token_ids, index = torch.vacc.rejection_sampler(
                proposal_verifier_probs,
                bonus_token_ids,
                proposal_probs,
                proposal_token_ids,
                0,
                seeded_seqs[0]
            )
        if LOG_LEVEL == "DEBUG":
            # Acceptance stats are only tracked in debug mode; the .cpu()
            # call forces a device sync, which is expensive on the hot path.
            self.spec_decode_sampler.num_accepted_tokens += index.cpu().sum()
            self.spec_decode_sampler.num_draft_tokens += (
                accepted_token_ids.shape[0] *
                (accepted_token_ids.shape[1] - 1))
    else:
        # Generic (non-fused) spec-decode sampler fallback.
        # Fix: bonus_token_ids was previously undefined on this path
        # (NameError for any non-RejectionSampler sampler); compute it as
        # upstream vLLM does — the last (bonus) token sampled by the
        # target model for each speculative sequence.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        # No fused kernel, so no precomputed accepted-index tensor.
        index = None

    # Append output tokens from non-speculative sequences to
    # the accepted token ids tensor.
    if len(non_spec_indices) != 0:
        non_spec_token_ids = non_spec_token_ids.expand(
            -1, max_proposal_len + 1).clone()
        # Only the first token of a non-spec sequence is valid output.
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])

    # NOTE(review): upstream vLLM rearranges accepted_token_ids back into
    # the original seq-group order here; intentionally disabled in this
    # port — callers therefore see [spec..., non_spec...] ordering.

    logprobs = proposal_scores.logprobs

    # B x K+1 x D
    hidden_states = proposal_scores.hidden_states
    if hidden_states is not None:
        # Only terminal (do_sample) sequences need hidden states for the
        # next step.
        terminal_metadata = [
            sg for sg in seq_group_metadata_list if sg.do_sample
        ]

        # Contract hidden states based on accepted tokens.
        if index is None:
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
            # Drop non-terminal prefill chunks hidden states.
            hidden_states = hidden_states[accepted_index !=
                                          VLLM_INVALID_TOKEN_ID]
            accepted_index = accepted_index[accepted_index !=
                                            VLLM_INVALID_TOKEN_ID]
        else:
            # The fused sampler already returned the accepted index.
            accepted_index = index
        assert len(accepted_index) == hidden_states.shape[0] == len(
            terminal_metadata)
        # Fused gather of the last / second-to-last accepted hidden states
        # (replaces upstream's explicit gather + [:, -2] slice).
        second_last_token_hidden_states, hidden_states = (
            torch.vacc.rejection_sampler_update_hidden_states(
                hidden_states, accepted_index))

        # Store hidden states from target model for subsequent decode step.
        self.previous_hidden_states = HiddenStates(
            hidden_states, terminal_metadata,
            second_last_token_hidden_states)

    return accepted_token_ids, logprobs
def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
scoring_time_ms: float,
verification_time_ms: float) -> None:
"""Log the speculative stage times. If stat logging is disabled, do
nothing.
"""
if self._disable_log_stats:
return
logger.debug(
"SpecDecodeWorker stage times: "
"average_time_per_proposal_tok_ms=%.02f "
"scoring_time_ms=%.02f verification_time_ms=%.02f",
average_time_per_proposal_tok_ms, scoring_time_ms,
verification_time_ms)
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                 skip_proposer: bool) -> List[SamplerOutput]:
    """Run a single generation step without any speculation. The input is
    sent to the proposer and scorer model so that the KV cache is consistent
    between the two. When skip_proposer is True, the proposer model is
    not called, meaning that the kv-cache in proposer for requests is not
    updated, so they cannot enable spec decode in the rest decoding.
    """
    # Run the target (scorer) model for exactly one step.
    sampler_output = self.scorer_worker.execute_model(execute_model_req)
    assert len(sampler_output) == 1
    sampler_output = sampler_output[0]

    # Store hidden states from target model execution, BxD.
    hidden_states = sampler_output.hidden_states
    if hidden_states is not None:
        # Only decodes and prefill terminal chunks need a hidden state.
        seq_group_meta_with_hidden = [
            sg for sg in execute_model_req.seq_group_metadata_list
            if sg.do_sample
        ]
        if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
            # Drop hidden_states with no prediction (eg non-terminal chunks):
            # rows whose sampled token id equals VLLM_INVALID_TOKEN_ID
            # (difference is zero) are filtered out by torch.where.
            hidden_states = hidden_states[
                torch.where(sampler_output.sampled_token_ids -
                            VLLM_INVALID_TOKEN_ID)[0]]
        # First step: create the cache; later steps: update it in place.
        if self.previous_hidden_states is None and len(
                seq_group_meta_with_hidden):
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_meta_with_hidden)
        elif self.previous_hidden_states and len(
                seq_group_meta_with_hidden):
            self.previous_hidden_states.update(hidden_states,
                                               seq_group_meta_with_hidden)
            # NOTE(review): upstream also prunes finished sequences here;
            # the prune call is intentionally disabled in this port.
            # self.previous_hidden_states.prune(seq_group_meta_with_hidden)

    if not skip_proposer:
        # We prepare the prefill hidden states here so that there no
        # additional complexity in worker for spec_decode vs non_spec_decode
        # flow and execute_model doesn't need additional modifications.
        execute_model_req.previous_hidden_states = \
            prepare_prefill_hidden_states(
                sampler_output.prefill_hidden_states)
        # Run the proposer for each spec-prefill step so its KV cache
        # stays consistent with the scorer's.
        for i in range(self._num_spec_prefill_steps):
            execute_model_req.spec_step_idx = i
            self.proposer_worker.execute_model(execute_model_req)

    # When logprobs are disabled, serialize a slimmed-down output instead
    # of shipping full device tensors back to the engine.
    sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
        execute_model_req=execute_model_req, sampler_output=sampler_output)
                                if self._disable_logprobs else
                                [sampler_output])

    # Clear device tensors from sampler output. This reduces communication
    # overhead when the engine runs in a different process than the workers.
    sampler_output.sampled_token_probs = None
    sampler_output.sampled_token_ids = None
    sampler_output.logprobs = None
    return sampler_output_to_return
def _prepare_prefill_hidden_states(
prefill_hidden_states: torch.Tensor) -> HiddenStates:
# For prefill step in proposer, we run the model for N-1 tokens
# because Nth token will be processed in the first decode step. For
# N-1 tokens, the input should be 0:N-1 hidden states which should
# be concatanated with 1:N token (since output of scorer has to be
# the input for proposer). Therefore, we shift the hidden states to
# align n-1th hidden state with nth token.
#print("prefill hiddens is:", prefill_hidden_states.shape, prefill_hidden_states.dtype,prefill_hidden_states)
if prefill_hidden_states is None:
return None
from torch_vacc.vacc.custom_ops import roll_out
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
roll_out_buffer = None
if memory_recycler is not None:
roll_out_buffer = memory_recycler.EMBEDDING_OUT_BUFFER
rolls = roll_out(prefill_hidden_states, shifts=1, dims=0, output=roll_out_buffer)
return HiddenStates(rolls)
class SpecDecodeWorker:

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model
        KV, such that the number of blocks is equal in both KV caches.

        Returns:
            A tuple of (num_gpu_blocks, num_cpu_blocks).
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())
        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        from vllm.utils import GiB_bytes
        # KV-cache budget comes from an env var in GiB; setting it to 0
        # opts into memory profiling to derive the budget instead.
        available_kv_cache_memory = int(
            os.getenv("VLLM_VACC_KVCACHE_SPACE", "16")) * GiB_bytes
        if available_kv_cache_memory == 0:
            # Profile the scorer's peak memory usage and compute the budget
            # as (utilization-scaled total memory) - peak.
            torch.vacc.empty_cache()
            torch.vacc.reset_peak_memory_stats()
            total_memory = torch.vacc.mem_get_info()[1]
            self.scorer_worker.model_runner.profile_run()
            torch.vacc.synchronize()
            peak_memory = torch.vacc.max_memory_allocated()
            torch.vacc.empty_cache()
            torch_allocated_bytes = torch.vacc.memory_stats(
            )["allocated_bytes.all.current"]
            # Account for device memory held outside the torch allocator
            # (e.g. driver/runtime allocations).
            total_allocated_bytes = torch.vacc.mem_get_info(
            )[1] - torch.vacc.mem_get_info()[0]
            non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
            if non_torch_allocations > 0:
                peak_memory += non_torch_allocations
            available_kv_cache_memory = (
                total_memory *
                self.scorer_worker.cache_config.gpu_memory_utilization -
                peak_memory)

        # Determine whether the current num_gpu_blocks meets the memory
        # requirements based on the combined block size of scorer + proposer.
        scorer_proposer_cache_bytes = (
            scorer_cache_block_size_bytes +
            proposer_cache_block_size_bytes) * num_gpu_blocks
        if scorer_proposer_cache_bytes < available_kv_cache_memory:
            # Both caches fit at the profiled block count; keep it.
            new_num_gpu_blocks = num_gpu_blocks
        else:
            # Otherwise split the scorer-only budget evenly between the
            # two models so each has the same number of blocks.
            from vllm.spec_decode.spec_decode_worker import (
                split_num_cache_blocks_evenly)
            new_num_gpu_blocks = split_num_cache_blocks_evenly(
                scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
                num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks