init
This commit is contained in:
313
vllm_vacc/vllm/spec_decode/spec_decode_worker.py
Normal file
313
vllm_vacc/vllm/spec_decode/spec_decode_worker.py
Normal file
@@ -0,0 +1,313 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import torch
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeStochasticBaseSampler
|
||||
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
|
||||
HiddenStates, SequenceGroupMetadata)
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScores)
|
||||
from vllm.spec_decode.util import split_batch_by_proposal_len
|
||||
from vllm.worker.worker_base import LoRANotSupportedWorkerBase
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
|
||||
CompletionSequenceGroupOutput, ExecuteModelRequest,
|
||||
HiddenStates, SequenceGroupMetadata,
|
||||
get_all_seq_ids_and_request_ids)
|
||||
from vllm.spec_decode.util import (Timer, create_logprobs_output,
|
||||
create_sequence_group_output,
|
||||
get_all_num_logprobs,
|
||||
get_sampled_token_logprobs, nvtx_range,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.spec_decode.spec_decode_worker import prepare_prefill_hidden_states
|
||||
|
||||
from vllm.spec_decode.spec_decode_worker import logger
|
||||
import os
|
||||
|
||||
# Effective logging level; gates the DEBUG-only acceptance-rate accounting
# in the verification path below.
LOG_LEVEL = os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper()
|
||||
|
||||
# Reminder: Please update docs/source/features/compatibility_matrix.md
|
||||
# If the feature combo become valid
|
||||
|
||||
def _verify_tokens(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_scores: SpeculativeScores,
    proposals: SpeculativeProposals,
    max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Determine which speculative tokens are accepted using the
    probabilities of each token according to the proposer and scorer models.

    Args:
        seq_group_metadata_list: Metadata for every sequence group in the
            batch, in original batch order.
        proposal_scores: Target-model scores (probs/token_ids/logprobs and
            optionally hidden states) for the proposals.
        proposals: Draft-model proposals (lens, probs, token ids).
        max_proposal_len: The batch-wide proposal length.

    Returns:
        A tuple of Tensors, one for the accepted token ids and one for
        the logprobs according to the scoring model.
    """
    proposal_lens_list = proposals.proposal_lens.tolist()

    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len. This adds some complexity (splitting the batch into spec
    # and non spec sequences) and should be removed in the future. It can be
    # done by supporting per-sequence proposal lens.
    (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
        seq_group_metadata_list, proposal_lens_list)

    # Get probabilities of target model, including bonus tokens.
    proposal_verifier_probs = proposal_scores.probs

    if len(non_spec_indices) == 0:
        non_spec_token_ids = None
    else:
        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

    # Bonus tokens from the target model. When the whole batch is
    # speculative, the full token_ids tensor is passed through; otherwise
    # only the last (bonus) column of the speculative rows.
    # NOTE(review): previously this was computed only on the
    # RejectionSampler path, so the generic sampler branch below raised
    # UnboundLocalError on `bonus_token_ids`; it is hoisted here so both
    # branches see the same value.
    bonus_token_ids = (proposal_scores.token_ids
                       if len(non_spec_indices) == 0 else
                       proposal_scores.token_ids[spec_indices, -1:])

    # Get probabilities according to proposal method.
    proposal_probs = proposals.proposal_probs

    # Get proposed tokens.
    proposal_token_ids = proposals.proposal_token_ids

    # Sampler arguments: forward per-sequence RNGs when the sampler is
    # stochastic and any sequence requested a seed.
    sampler_extra_kwargs: Dict[str, Any] = {}
    if self.generators and isinstance(self.spec_decode_sampler,
                                      SpecDecodeStochasticBaseSampler):
        sampler_extra_kwargs["seeded_seqs"] = {
            idx: self.generators[sgm.request_id]
            for idx, sgm in enumerate(seq_group_metadata_list)
            if sgm.sampling_params.seed is not None
        }

    if isinstance(self.spec_decode_sampler, RejectionSampler):
        if len(sampler_extra_kwargs) > 0 and len(
                sampler_extra_kwargs["seeded_seqs"]) > 0:
            seeded_seqs = sampler_extra_kwargs["seeded_seqs"]
        else:
            seeded_seqs = None
        # Hardware-accelerated rejection sampling. The trailing flag is 1
        # for unseeded and 0 for seeded sampling; only the generator at
        # batch index 0 is forwarded — presumably a single-seed limitation
        # of the vacc kernel, TODO confirm.
        if seeded_seqs is None:
            accepted_token_ids, index = torch.vacc.rejection_sampler(
                proposal_verifier_probs,
                bonus_token_ids,
                proposal_probs,
                proposal_token_ids,
                1
            )
        else:
            accepted_token_ids, index = torch.vacc.rejection_sampler(
                proposal_verifier_probs,
                bonus_token_ids,
                proposal_probs,
                proposal_token_ids,
                0,
                seeded_seqs[0]
            )
        if LOG_LEVEL == "DEBUG":
            # Acceptance-rate accounting, kept off the hot path unless
            # debug logging is enabled.
            self.spec_decode_sampler.num_accepted_tokens += index.cpu().sum()
            self.spec_decode_sampler.num_draft_tokens += (
                accepted_token_ids.shape[0] *
                (accepted_token_ids.shape[1] - 1))
    else:
        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        index = None

    # Append output tokens from non-speculative sequences to
    # the accepted token ids tensor.
    if len(non_spec_indices) != 0:
        non_spec_token_ids = non_spec_token_ids.expand(
            -1, max_proposal_len + 1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
    # NOTE(review): results are NOT rearranged back into the original
    # seq-group-metadata order here (the upstream rearrange is disabled in
    # this fork): spec rows come first, followed by non-spec rows.

    logprobs = proposal_scores.logprobs
    # B x K+1 x D
    hidden_states = proposal_scores.hidden_states
    if hidden_states is not None:
        # Only get terminal hidden states for next step.
        terminal_metadata = [
            sg for sg in seq_group_metadata_list if sg.do_sample
        ]

        # Contract hidden states based on accepted tokens.
        if index is None:
            # Recover the per-sequence position of the last accepted token
            # from the -1 padding in accepted_token_ids.
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
            # Drop non-terminal prefill chunks hidden states.
            hidden_states = hidden_states[accepted_index !=
                                          VLLM_INVALID_TOKEN_ID]
            accepted_index = accepted_index[accepted_index !=
                                            VLLM_INVALID_TOKEN_ID]
        else:
            # The vacc rejection sampler already reported the index.
            accepted_index = index
        assert len(accepted_index) == hidden_states.shape[0] == len(
            terminal_metadata)
        # Fused hardware op gathering the accepted token's hidden state and
        # the one preceding it for each sequence.
        second_last_token_hidden_states, hidden_states = (
            torch.vacc.rejection_sampler_update_hidden_states(
                hidden_states, accepted_index))
        # Store hidden states from target model for subsequent decode step.
        self.previous_hidden_states = HiddenStates(
            hidden_states, terminal_metadata,
            second_last_token_hidden_states)
    return accepted_token_ids, logprobs
|
||||
|
||||
|
||||
def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
|
||||
scoring_time_ms: float,
|
||||
verification_time_ms: float) -> None:
|
||||
"""Log the speculative stage times. If stat logging is disabled, do
|
||||
nothing.
|
||||
"""
|
||||
if self._disable_log_stats:
|
||||
return
|
||||
logger.debug(
|
||||
"SpecDecodeWorker stage times: "
|
||||
"average_time_per_proposal_tok_ms=%.02f "
|
||||
"scoring_time_ms=%.02f verification_time_ms=%.02f",
|
||||
average_time_per_proposal_tok_ms, scoring_time_ms,
|
||||
verification_time_ms)
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                 skip_proposer: bool) -> List[SamplerOutput]:
    """Run one generation step with speculation disabled.

    The input is executed on the scorer model and, unless
    ``skip_proposer`` is set, replayed on the proposer model so the two
    KV caches stay consistent. When ``skip_proposer`` is True the
    proposer is not called, meaning its KV cache for these requests is
    not updated, so they cannot enable spec decode later in decoding.
    """
    outputs = self.scorer_worker.execute_model(execute_model_req)
    assert len(outputs) == 1
    target_output = outputs[0]

    # Store hidden states from target model execution, BxD.
    target_hidden = target_output.hidden_states
    if target_hidden is not None:
        # Only decodes and prefill terminal chunks need a hidden state.
        sampled_groups = [
            sg for sg in execute_model_req.seq_group_metadata_list
            if sg.do_sample
        ]
        if any(sg.is_prompt for sg in sampled_groups):
            # Drop hidden_states with no prediction
            # (e.g. non-terminal prefill chunks).
            keep_rows = torch.where(target_output.sampled_token_ids -
                                    VLLM_INVALID_TOKEN_ID)[0]
            target_hidden = target_hidden[keep_rows]
        if len(sampled_groups):
            if self.previous_hidden_states is None:
                self.previous_hidden_states = HiddenStates(
                    target_hidden, sampled_groups)
            elif self.previous_hidden_states:
                self.previous_hidden_states.update(target_hidden,
                                                   sampled_groups)

    if not skip_proposer:
        # We prepare the prefill hidden states here so that there is no
        # additional complexity in worker for spec_decode vs
        # non_spec_decode flow and execute_model doesn't need additional
        # modifications.
        execute_model_req.previous_hidden_states = \
            prepare_prefill_hidden_states(
                target_output.prefill_hidden_states)
        for step in range(self._num_spec_prefill_steps):
            execute_model_req.spec_step_idx = step
            self.proposer_worker.execute_model(execute_model_req)

    if self._disable_logprobs:
        result = self._serialize_sampler_output_no_logprobs(
            execute_model_req=execute_model_req,
            sampler_output=target_output)
    else:
        result = [target_output]

    # Clear device tensors from sampler output. This reduces communication
    # overhead when the engine runs in a different process than the
    # workers.
    target_output.sampled_token_probs = None
    target_output.sampled_token_ids = None
    target_output.logprobs = None
    return result
|
||||
|
||||
def _prepare_prefill_hidden_states(
|
||||
prefill_hidden_states: torch.Tensor) -> HiddenStates:
|
||||
# For prefill step in proposer, we run the model for N-1 tokens
|
||||
# because Nth token will be processed in the first decode step. For
|
||||
# N-1 tokens, the input should be 0:N-1 hidden states which should
|
||||
# be concatanated with 1:N token (since output of scorer has to be
|
||||
# the input for proposer). Therefore, we shift the hidden states to
|
||||
# align n-1th hidden state with nth token.
|
||||
#print("prefill hiddens is:", prefill_hidden_states.shape, prefill_hidden_states.dtype,prefill_hidden_states)
|
||||
if prefill_hidden_states is None:
|
||||
return None
|
||||
from torch_vacc.vacc.custom_ops import roll_out
|
||||
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
|
||||
roll_out_buffer = None
|
||||
if memory_recycler is not None:
|
||||
roll_out_buffer = memory_recycler.EMBEDDING_OUT_BUFFER
|
||||
|
||||
rolls = roll_out(prefill_hidden_states, shifts=1, dims=0, output=roll_out_buffer)
|
||||
return HiddenStates(rolls)
|
||||
|
||||
class SpecDecodeWorker:

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Compute how many KV-cache blocks to use.

        The scorer model (typically the larger of the two) determines the
        baseline block count. The total cache memory is then divided
        evenly between the proposer and scorer models' KV caches, so both
        end up with the same number of blocks.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_block_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())
        proposer_block_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        from vllm.utils import GiB_bytes

        # KV-cache budget comes from the environment; a value of 0 means
        # "measure it by profiling instead".
        kv_cache_budget = int(
            os.getenv("VLLM_VACC_KVCACHE_SPACE", "16")) * GiB_bytes

        if kv_cache_budget == 0:
            torch.vacc.empty_cache()
            torch.vacc.reset_peak_memory_stats()
            device_total = torch.vacc.mem_get_info()[1]
            self.scorer_worker.model_runner.profile_run()
            torch.vacc.synchronize()
            peak_bytes = torch.vacc.max_memory_allocated()
            torch.vacc.empty_cache()
            torch_bytes = torch.vacc.memory_stats(
            )["allocated_bytes.all.current"]
            used_bytes = torch.vacc.mem_get_info(
            )[1] - torch.vacc.mem_get_info()[0]
            # Account for allocations made outside the torch allocator.
            non_torch_bytes = used_bytes - torch_bytes
            if non_torch_bytes > 0:
                peak_bytes += non_torch_bytes
            kv_cache_budget = (
                device_total *
                self.scorer_worker.cache_config.gpu_memory_utilization -
                peak_bytes)

        # Check whether the scorer-derived block count already fits once
        # both models' per-block sizes are accounted for.
        combined_cache_bytes = (scorer_block_bytes +
                                proposer_block_bytes) * num_gpu_blocks
        if combined_cache_bytes < kv_cache_budget:
            return num_gpu_blocks, num_cpu_blocks

        from vllm.spec_decode.spec_decode_worker import (
            split_num_cache_blocks_evenly)
        new_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_block_bytes, proposer_block_bytes, num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks
|
||||
Reference in New Issue
Block a user