init
vllm_vacc/vllm/spec_decode/__init__.py (new file, 0 lines)
vllm_vacc/vllm/spec_decode/__pycache__/__init__.cpython-312.pyc (new binary file, not shown)
vllm_vacc/vllm/spec_decode/__pycache__/metrics.cpython-312.pyc (new binary file, not shown)
(three additional binary files not shown)
vllm_vacc/vllm/spec_decode/__pycache__/var.cpython-312.pyc (new binary file, not shown)
vllm_vacc/vllm/spec_decode/metrics.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import torch


class AsyncMetricsCollector:
    """Class which copies rejection/typical-acceptance sampler metrics
    from the device to CPU on a non-default Torch stream.
    """

    def _copy_rejsample_metrics_async(self):
        """Copy rejection/typical-acceptance sampling metrics
        (number of accepted tokens, etc.) to CPU asynchronously.

        Returns a device event recording when the copy is complete.
        """
        # Ensure the vacc backend is registered before using torch.vacc.
        import torch_vacc
        assert self._copy_stream is not None
        self._copy_stream.wait_stream(torch.vacc.current_stream())

        with torch.vacc.stream(self._copy_stream):
            self._aggregate_num_accepted_tokens.copy_(
                self.spec_decode_sampler.num_accepted_tokens,
                non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = (
                self.spec_decode_sampler.num_draft_tokens)

        aggregate_metrics_ready = torch.vacc.Event()
        aggregate_metrics_ready.record(self._copy_stream)

        return aggregate_metrics_ready
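For context, a hedged sketch of how a caller might consume the returned event (the helper name and the polling pattern are assumptions, mirroring upstream vLLM's AsyncMetricsCollector; Event.query() is assumed to behave like torch.cuda.Event.query()): the CPU-side aggregates should only be read once the asynchronous copy has completed.

def maybe_collect_metrics(collector):
    # Kick off the async device-to-CPU copy and keep the completion event.
    event = collector._copy_rejsample_metrics_async()
    # query() is a non-blocking check that the copy stream has reached the
    # recorded point.
    if event.query():
        accepted = collector._aggregate_num_accepted_tokens.item()
        emitted = collector._aggregate_num_emitted_tokens.item()
        drafted = collector._aggregate_num_draft_tokens
        return accepted, emitted, drafted
    return None  # Copy still in flight; poll again on a later step.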
vllm_vacc/vllm/spec_decode/multi_step_worker.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from typing import List, Set, Tuple

import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker_base import DelegateWorkerBase

from .var import *


class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ):
        """Run the model forward pass sample_len times. Returns the list of
        sampler outputs, one per model forward pass, along with an indicator
        of whether the torch tensors in the sampler output need to be
        transposed in the later sampler_output_to_torch logic.

        For the multi-step worker, this indicator shall be True.
        """
        self._raise_if_unsupported(execute_model_req)
        # Expand the batch for sequences with a bonus token.
        # Perform a forward pass on the expanded batch and filter the
        # response to retain only the original sequences' responses.
        expanded_request, indices_of_seq_with_bonus_tokens = \
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []

        # Here we run multi-step directly, with every step prepared
        # on the CPU.
        # TODO: Remove this branch once DraftModelRunner supports TP>1
        # and other restrictions that are part of DraftModelRunner's
        # supports_gpu_multi_step(..)
        if expanded_request.previous_hidden_states is not None:
            self.worker.model_runner.return_hidden_states = True
        for _ in range(sample_len):
            model_output: List[SamplerOutput] = self.worker.execute_model(
                execute_model_req=expanded_request)
            assert (len(model_output) == 1
                    ), "composing multistep workers not supported"
            model_output = model_output[0]
            self._maybe_update_previous_hidden_states(
                model_output, expanded_request)

            self._append_new_tokens(
                model_output, expanded_request.seq_group_metadata_list,
                indices_of_seq_with_bonus_tokens)
            model_outputs.append(model_output)

        # Tensor selection for the outputs happens inside the fused operator.
        if USE_FUSED_MTP_SAMPLER:
            return model_outputs, indices_of_seq_with_bonus_tokens, True

        # Move indices to device to avoid a stream sync.
        indices_of_seq_with_bonus_tokens = torch.tensor(
            indices_of_seq_with_bonus_tokens, device=self.device)
        filtered_model_outputs = self._filter_model_output(
            model_outputs, indices_of_seq_with_bonus_tokens)
        return filtered_model_outputs, True
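For context, a simplified illustration (assumed, not this commit's _expand_execute_model_request) of why the batch is expanded: a sequence that received a bonus token in the last step is scored both with and without that token, and the positions of the original entries are remembered so the extra copies can be filtered out afterwards.

def expand_batch(seq_ids, seq_ids_with_bonus):
    expanded = []          # Sequence entries in the expanded batch.
    original_indices = []  # Positions of the original sequences.
    for seq_id in seq_ids:
        if seq_id in seq_ids_with_bonus:
            # Extra copy whose input drops the bonus token.
            expanded.append((seq_id, "without_bonus"))
        expanded.append((seq_id, "with_full_tokens"))
        original_indices.append(len(expanded) - 1)
    return expanded, original_indices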
vllm_vacc/vllm/spec_decode/spec_decode_worker.py (new file, 313 lines)
@@ -0,0 +1,313 @@
import os
from typing import Any, Dict, List, Tuple

import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeStochasticBaseSampler)
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
                           HiddenStates, SequenceGroupMetadata,
                           get_all_seq_ids_and_request_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScores)
from vllm.spec_decode.spec_decode_worker import (
    logger, prepare_prefill_hidden_states)
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
                                   get_all_num_logprobs,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.worker.worker_base import LoRANotSupportedWorkerBase

LOG_LEVEL = os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper()

# Reminder: Please update docs/source/features/compatibility_matrix.md
# if the feature combo becomes valid.

def _verify_tokens(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_scores: SpeculativeScores,
    proposals: SpeculativeProposals,
    max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Determine which speculative tokens are accepted using the
    probabilities of each token according to the proposer and scorer models.

    Returns a tuple of Tensors, one for the accepted token ids and one for
    the logprobs according to the scoring model.
    """
    proposal_lens_list = proposals.proposal_lens.tolist()

    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len. This adds some complexity (splitting the batch into spec
    # and non spec sequences) and should be removed in the future. It can be
    # done by supporting per-sequence proposal lens.
    (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
        seq_group_metadata_list, proposal_lens_list)
    original_indices = spec_indices + non_spec_indices

    # Get probabilities of target model, including bonus tokens.
    proposal_verifier_probs = proposal_scores.probs

    if len(non_spec_indices) == 0:
        non_spec_token_ids = None
    else:
        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

    # # Get bonus tokens from target model.
    # bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

    # Get probabilities according to proposal method.
    proposal_probs = proposals.proposal_probs

    # Get proposed tokens.
    proposal_token_ids = proposals.proposal_token_ids

    # Sampler arguments.
    sampler_extra_kwargs: Dict[str, Any] = {}
    if self.generators and isinstance(self.spec_decode_sampler,
                                      SpecDecodeStochasticBaseSampler):
        sampler_extra_kwargs["seeded_seqs"] = {
            idx: self.generators[sgm.request_id]
            for idx, sgm in enumerate(seq_group_metadata_list)
            if sgm.sampling_params.seed is not None
        }
    if isinstance(self.spec_decode_sampler, RejectionSampler):
        # Get bonus tokens from target model.
        bonus_token_ids = (proposal_scores.token_ids
                           if len(non_spec_indices) == 0 else
                           proposal_scores.token_ids[spec_indices, -1:])
        if (len(sampler_extra_kwargs) > 0
                and len(sampler_extra_kwargs["seeded_seqs"]) > 0):
            seeded_seqs = sampler_extra_kwargs["seeded_seqs"]
        else:
            seeded_seqs = None
        if seeded_seqs is None:
            accepted_token_ids, index = torch.vacc.rejection_sampler(
                proposal_verifier_probs,
                bonus_token_ids,
                proposal_probs,
                proposal_token_ids,
                1)
        else:
            accepted_token_ids, index = torch.vacc.rejection_sampler(
                proposal_verifier_probs,
                bonus_token_ids,
                proposal_probs,
                proposal_token_ids,
                0,
                seeded_seqs[0])
        if LOG_LEVEL == "DEBUG":
            self.spec_decode_sampler.num_accepted_tokens += index.cpu().sum()
            self.spec_decode_sampler.num_draft_tokens += (
                accepted_token_ids.shape[0] *
                (accepted_token_ids.shape[1] - 1))
    else:
        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        index = None

    # Append output tokens from non-speculative sequences to
    # the accepted token ids tensor.
    if len(non_spec_indices) != 0:
        non_spec_token_ids = non_spec_token_ids.expand(
            -1, max_proposal_len + 1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
    # # Rearrange so that results are in the order of the original seq group
    # # metadata.
    # accepted_token_ids[original_indices] = accepted_token_ids.clone()

    logprobs = proposal_scores.logprobs
    # B x K+1 x D
    hidden_states = proposal_scores.hidden_states
    if hidden_states is not None:
        # Only get terminal hidden states for next step.
        terminal_metadata = [
            sg for sg in seq_group_metadata_list if sg.do_sample
        ]

        # Contract hidden states based on accepted tokens.
        hs_size = hidden_states.shape[-1]
        if index is None:
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
            # Drop non-terminal prefill chunks' hidden states.
            hidden_states = hidden_states[accepted_index !=
                                          VLLM_INVALID_TOKEN_ID]
            accepted_index = accepted_index[accepted_index !=
                                            VLLM_INVALID_TOKEN_ID]

            # assert index.tolist()[0] == accepted_index.tolist()[0]
        else:
            accepted_index = index
        assert len(accepted_index) == hidden_states.shape[0] == len(
            terminal_metadata)
        # index = accepted_index[:, None, None].expand(-1, 1,
        #                                              hs_size)  # b x 1 x d
        # second_last_token_hidden_states = hidden_states[:, -2]  # b x d
        # hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
        second_last_token_hidden_states, hidden_states = \
            torch.vacc.rejection_sampler_update_hidden_states(
                hidden_states, accepted_index)
        # Store hidden states from target model for subsequent decode step.
        self.previous_hidden_states = HiddenStates(
            hidden_states, terminal_metadata,
            second_last_token_hidden_states)
    return accepted_token_ids, logprobs
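
# For reference, a hedged sketch of the standard rejection-sampling rule the
# fused torch.vacc.rejection_sampler kernel is assumed to implement: draft
# token t is accepted with probability min(1, p_target(t) / p_draft(t)), and
# positions after the first rejection are discarded (marked -1 here). This
# is an illustration, not this commit's kernel.
def _reference_rejection_sample(target_probs, draft_probs, draft_token_ids):
    # target_probs, draft_probs: [batch, k, vocab]; draft_token_ids: [batch, k]
    p_target = target_probs.gather(
        -1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
    p_draft = draft_probs.gather(
        -1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
    accept_prob = (p_target / p_draft).clamp(max=1.0)
    accepted = torch.rand_like(accept_prob) < accept_prob
    # A draft token counts only if every earlier position was accepted.
    keep = accepted.int().cumprod(dim=1).bool()
    return torch.where(keep, draft_token_ids,
                       torch.full_like(draft_token_ids, -1))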

def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                           scoring_time_ms: float,
                           verification_time_ms: float) -> None:
    """Log the speculative stage times. If stat logging is disabled, do
    nothing.
    """
    if self._disable_log_stats:
        return
    logger.debug(
        "SpecDecodeWorker stage times: "
        "average_time_per_proposal_tok_ms=%.02f "
        "scoring_time_ms=%.02f verification_time_ms=%.02f",
        average_time_per_proposal_tok_ms, scoring_time_ms,
        verification_time_ms)

@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                 skip_proposer: bool) -> List[SamplerOutput]:
    """Run a single generation step without any speculation. The input is
    sent to the proposer and scorer model so that the KV cache is consistent
    between the two. When skip_proposer is True, the proposer model is not
    called, meaning that the proposer's KV cache is not updated for these
    requests, so they cannot enable spec decode for the rest of decoding.
    """

    sampler_output = self.scorer_worker.execute_model(execute_model_req)
    assert len(sampler_output) == 1
    sampler_output = sampler_output[0]

    # Store hidden states from target model execution, BxD.
    hidden_states = sampler_output.hidden_states
    if hidden_states is not None:
        # Only decodes and prefill terminal chunks need a hidden state.
        seq_group_meta_with_hidden = [
            sg for sg in execute_model_req.seq_group_metadata_list
            if sg.do_sample
        ]
        if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
            # Drop hidden_states with no prediction (e.g. non-terminal
            # chunks).
            hidden_states = hidden_states[
                torch.where(sampler_output.sampled_token_ids -
                            VLLM_INVALID_TOKEN_ID)[0]]
        if self.previous_hidden_states is None and len(
                seq_group_meta_with_hidden):
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_meta_with_hidden)
        elif self.previous_hidden_states and len(
                seq_group_meta_with_hidden):
            self.previous_hidden_states.update(hidden_states,
                                               seq_group_meta_with_hidden)
            # self.previous_hidden_states.prune(seq_group_meta_with_hidden)

    if not skip_proposer:
        # We prepare the prefill hidden states here so that there is no
        # additional complexity in the worker for the spec_decode vs
        # non_spec_decode flow, and execute_model doesn't need additional
        # modifications.
        execute_model_req.previous_hidden_states = \
            prepare_prefill_hidden_states(
                sampler_output.prefill_hidden_states)
        for i in range(self._num_spec_prefill_steps):
            execute_model_req.spec_step_idx = i
            self.proposer_worker.execute_model(execute_model_req)

    sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
        execute_model_req=execute_model_req, sampler_output=sampler_output)
                                if self._disable_logprobs else
                                [sampler_output])

    # Clear device tensors from sampler output. This reduces communication
    # overhead when the engine runs in a different process than the workers.
    sampler_output.sampled_token_probs = None
    sampler_output.sampled_token_ids = None
    sampler_output.logprobs = None
    return sampler_output_to_return

def _prepare_prefill_hidden_states(
        prefill_hidden_states: torch.Tensor) -> HiddenStates:
    # For the prefill step in the proposer, we run the model for N-1 tokens
    # because the Nth token will be processed in the first decode step. For
    # N-1 tokens, the input should be the 0:N-1 hidden states, which should
    # be concatenated with the 1:N tokens (since the output of the scorer
    # has to be the input for the proposer). Therefore, we shift the hidden
    # states to align the (n-1)th hidden state with the nth token.
    # print("prefill hiddens is:", prefill_hidden_states.shape,
    #       prefill_hidden_states.dtype, prefill_hidden_states)
    if prefill_hidden_states is None:
        return None
    from torch_vacc.vacc.custom_ops import roll_out
    from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
        memory_recycler)
    roll_out_buffer = None
    if memory_recycler is not None:
        roll_out_buffer = memory_recycler.EMBEDDING_OUT_BUFFER

    rolls = roll_out(prefill_hidden_states, shifts=1, dims=0,
                     output=roll_out_buffer)
    return HiddenStates(rolls)
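
# For reference, upstream vLLM performs this shift with Tensor.roll; a hedged
# stock-PyTorch equivalent of the custom roll_out op above (assuming roll_out
# matches torch.roll semantics, and ignoring the reusable output buffer):
def _prepare_prefill_hidden_states_reference(
        prefill_hidden_states: torch.Tensor) -> HiddenStates:
    if prefill_hidden_states is None:
        return None
    # Shift hidden states one position so the (n-1)th hidden state lines up
    # with the nth token.
    return HiddenStates(prefill_hidden_states.roll(shifts=1, dims=0))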

class SpecDecodeWorker():

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model
        KV, such that the number of blocks is equal in both KV caches.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())

        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        from vllm.utils import GiB_bytes
        available_kv_cache_memory = int(
            os.getenv("VLLM_VACC_KVCACHE_SPACE", "16")) * GiB_bytes

        if available_kv_cache_memory == 0:
            # Fall back to profiling: measure the peak memory of a profile
            # run and derive the budget from gpu_memory_utilization.
            torch.vacc.empty_cache()
            torch.vacc.reset_peak_memory_stats()
            total_memory = torch.vacc.mem_get_info()[1]
            self.scorer_worker.model_runner.profile_run()
            torch.vacc.synchronize()
            peak_memory = torch.vacc.max_memory_allocated()
            torch.vacc.empty_cache()
            torch_allocated_bytes = torch.vacc.memory_stats(
            )["allocated_bytes.all.current"]
            total_allocated_bytes = torch.vacc.mem_get_info(
            )[1] - torch.vacc.mem_get_info()[0]
            non_torch_allocations = (total_allocated_bytes -
                                     torch_allocated_bytes)
            if non_torch_allocations > 0:
                peak_memory += non_torch_allocations
            available_kv_cache_memory = (
                total_memory *
                self.scorer_worker.cache_config.gpu_memory_utilization -
                peak_memory)

        # Determine whether the current num_gpu_blocks fits the memory
        # budget, based on the combined block size of the scorer + proposer
        # model.
        scorer_proposer_cache_bytes = (
            (scorer_cache_block_size_bytes + proposer_cache_block_size_bytes)
            * num_gpu_blocks)

        if scorer_proposer_cache_bytes < available_kv_cache_memory:
            new_num_gpu_blocks = num_gpu_blocks
        else:
            from vllm.spec_decode.spec_decode_worker import (
                split_num_cache_blocks_evenly)
            new_num_gpu_blocks = split_num_cache_blocks_evenly(
                scorer_cache_block_size_bytes,
                proposer_cache_block_size_bytes, num_gpu_blocks)
        # print("spec decoder info: available_kv_cache_memory",
        #       available_kv_cache_memory, "\n",
        #       "scorer_proposer_cache_bytes:", scorer_proposer_cache_bytes,
        #       "\n", new_num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks
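For context, a sketch of the even-split arithmetic behind split_num_cache_blocks_evenly (mirroring upstream vLLM's helper; treat it as a reference, not this commit's code): the scorer-only block budget is redistributed so that both KV caches end up with the same number of blocks.

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    # Total bytes the scorer alone would have used, divided by the combined
    # cost of one block in each cache, gives the per-cache block count.
    new_num_gpu_blocks = int(
        total_num_gpu_blocks * scorer_cache_block_size_bytes /
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
    return new_num_gpu_blocks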
vllm_vacc/vllm/spec_decode/top1_proposer.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Set

import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)

from .var import *


class Top1Proposer(SpeculativeProposer):

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        proposal_len = execute_model_req.num_lookahead_slots
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Split speculative and non-speculative sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, then maybe_sampler_output's
            # token_ids is in [batch] format in a proposal_len-size list,
            # while if it is false, the format is [proposal_len] in a
            # batch-size list.
            hidden_states = execute_model_req.previous_hidden_states
            if hidden_states is not None:
                hidden_states.prune(nonzero_proposal_len_seqs)
            nonzero_execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                num_lookahead_slots=proposal_len,
                previous_hidden_states=hidden_states,
            )

            # Idea: fuse the sampler output and merge-output steps into a
            # single op. Removing no-proposal seqs is unnecessary on this
            # path, since every output is a draft output and there is no
            # mixing. The fused op: (1) filters the indices, (2) returns the
            # tensors, (3) transposes the tensors.
            if USE_FUSED_MTP_SAMPLER:
                sampler_outputs, token_indices, transposed = (
                    self._worker.sampler_output(
                        execute_model_req=nonzero_execute_model_req,
                        sample_len=proposal_len,
                        seq_ids_with_bonus_token_in_last_step=(
                            seq_ids_with_bonus_token_in_last_step),
                    ))

                outputs = sampler_outputs[0].outputs
                token_probs = sampler_outputs[0].sampled_token_probs
                token_ids = sampler_outputs[0].sampled_token_ids

                bs = len(seq_group_metadata_list)
                proposal_lens = torch.ones(bs, dtype=torch.int,
                                           device=self._worker.device)

                proposal_tokens = token_ids[token_indices]
                proposal_probs = token_probs[token_indices]

                s0, s1 = proposal_probs.shape
                proposal_probs = proposal_probs.view(s0, 1, s1)

                # Filter the indices and build the new SpeculativeProposals.
                proposals = SpeculativeProposals(
                    proposal_token_ids=proposal_tokens,
                    proposal_probs=proposal_probs,
                    proposal_lens=proposal_lens,
                    no_proposals=outputs is None)
                return proposals

            maybe_sampler_output, transposed = self._worker.sampler_output(
                execute_model_req=nonzero_execute_model_req,
                sample_len=proposal_len,
                seq_ids_with_bonus_token_in_last_step=(
                    seq_ids_with_bonus_token_in_last_step),
            )
            (
                proposal_lens,
                maybe_sampler_output,
                nonzero_proposal_len_indices,
            ) = self._remove_no_proposal_seqs(proposal_lens,
                                              maybe_sampler_output,
                                              nonzero_proposal_len_indices,
                                              transposed)
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens,
                                         proposal_probs=proposal_probs,
                                         proposal_lens=proposal_lens,
                                         no_proposals=maybe_sampler_output
                                         is None)
        return proposals
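For context, a hedged sketch of what _merge_outputs on the unfused path is assumed to do (mirroring upstream vLLM's behavior): skipped sequences receive placeholder token ids and zero proposal lengths so the batch keeps its original order and a fixed shape.

import torch

def merge_outputs_sketch(batch_size, proposal_len, proposal_tokens,
                         proposal_probs, nonzero_proposal_len_indices,
                         vocab_size, device):
    # Placeholders for every sequence; real proposals are scattered back in
    # at the positions of the speculated sequences.
    entire_tokens = torch.full((batch_size, proposal_len), -1,
                               dtype=torch.long, device=device)
    entire_probs = torch.zeros(batch_size, proposal_len, vocab_size,
                               device=device)
    entire_lens = torch.zeros(batch_size, dtype=torch.long, device=device)
    entire_tokens[nonzero_proposal_len_indices] = proposal_tokens
    entire_probs[nonzero_proposal_len_indices] = proposal_probs
    entire_lens[nonzero_proposal_len_indices] = proposal_len
    return entire_tokens, entire_probs, entire_lens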
vllm_vacc/vllm/spec_decode/var.py (new file, 3 lines)
@@ -0,0 +1,3 @@
# qwen fused attention
import os
USE_FUSED_MTP_SAMPLER = int(os.getenv("USE_FUSED_MTP_SAMPLER", "1"))
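Because USE_FUSED_MTP_SAMPLER is read once at module import time, the variable must be set in the environment before the module is imported; a hedged usage sketch:

import os
os.environ["USE_FUSED_MTP_SAMPLER"] = "0"  # Disable the fused sampler path.

# Must happen after the environment variable is set.
from vllm_vacc.vllm.spec_decode import var
assert var.USE_FUSED_MTP_SAMPLER == 0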