import copy
import gc
import time
import weakref
from contextlib import contextmanager

from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from tqdm import tqdm
from typing_extensions import TypeAlias

import vllm.envs as envs
from vllm.attention import (AttentionMetadata, AttentionType,
                            get_attn_backend)
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention

from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_rank, get_tp_group, graph_capture,
    is_global_first_rank, prepare_communication_buffer_for_model)
from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                  set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import (TensorizerLoader, get_model,
                                              get_model_loader)
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.interfaces_base import (
    VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
                        check_use_alibi, get_dtype_size,
                        is_pin_memory_available, round_up)
from vllm.utils.jsontree import json_map_leaves

from vllm.v1.attention.backends.utils import (
    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
    create_fast_prefill_custom_backend,
    reorder_batch_to_split_decodes_and_prefills, split_attn_metadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        ChunkedLocalAttentionSpec,
                                        CrossAttentionSpec,
                                        EncoderOnlyAttentionSpec,
                                        FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheSpec,
                                        MambaSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, LogprobsLists, LogprobsTensors,
                             ModelRunnerOutput, PoolerOutput, SamplerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import _torch_cuda_wrapper
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput,
                                             GPUModelRunner)
from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds,
                                             ubatch_split)
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
from vllm.v1.worker.utils import is_residual_scattered_for_sp
# from vllm.sample.logits_processor import LogitsProcessorManager
# from vllm.worker.utils import (gather_mm_placeholders,
#                                initialize_kv_cache_for_kv_sharing,
#                                sanity_check_mm_encoder_outputs,
#                                scatter_mm_placeholders)

from vacc_tools.trace_logger import get_trace_api

(trace_time, register_module_trace, trace_autograd_function,
 register_optimizer_trace) = get_trace_api("deepseek")

logger = init_logger(__name__)

AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
# list when ubatching is enabled
PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
                                        AttnMetadataDict]

def draft_model_eagle_load_model(draft_model, target_model: nn.Module) -> None:
    """Load the EAGLE draft model and share weights with the target model.

    Adapted from EagleProposer.load_model: when possible the draft head
    reuses the target model's vocab embedding (single-PP-rank case with
    matching shapes) and its lm_head (except for the eagle3 method), so only
    the EAGLE-specific layers come from the draft checkpoint.
    """
    draft_model_config = \
        draft_model.vllm_config.speculative_config.draft_model_config
    target_attn_layer_names = set(
        get_layers_from_vllm_config(draft_model.vllm_config, Attention).keys())

    # from vllm.compilation.backends import set_model_tag
    # with set_model_tag("eagle_head"):
    draft_model.model = get_model(vllm_config=draft_model.vllm_config,
                                  model_config=draft_model_config)

    draft_attn_layer_names = (
        get_layers_from_vllm_config(draft_model.vllm_config, Attention).keys() -
        target_attn_layer_names)

    draft_model.attn_layer_names = list(draft_attn_layer_names)

    if supports_multimodal(target_model):
        # handle multimodality
        draft_model.model.config.image_token_index = (
            target_model.config.image_token_index)
        target_language_model = target_model.get_language_model()
    else:
        target_language_model = target_model

    # Share embed_tokens with the target model if needed.
    if get_pp_group().world_size == 1 \
        and draft_model.model.model.embed_tokens.weight.shape \
            == target_language_model.model.embed_tokens.weight.shape:
        logger.info(
            "Assuming the EAGLE head shares the same vocab embedding"
            " with the target model.")
        del draft_model.model.model.embed_tokens
        draft_model.model.model.embed_tokens = (
            target_language_model.model.embed_tokens)
    else:
        logger.info(
            "The EAGLE head's vocab embedding will be loaded separately"
            " from the target model.")

    # Share lm_head with the target model if needed. Some model definitions
    # do not define lm_head explicitly and reuse embed_tokens for lm_head,
    # e.g., CohereForCausalLM.
    if draft_model.vllm_config.speculative_config.method != "eagle3" and \
            hasattr(target_language_model, "lm_head"):
        logger.info("Loading EAGLE LM head weights from the target model.")
        draft_model.model.lm_head = target_language_model.lm_head

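
# Summary of how VACCModelRunner below deviates from the upstream
# GPUModelRunner (all points taken from the overrides in this file):
#   * CUDA graphs and cascade attention are disabled.
#   * The positions / num_accepted_tokens buffers are re-allocated as int32.
#   * seq_lens is kept as a host-side Python list instead of a CpuGpuBuffer.
#   * Device-side gathers go through torch.ops.aten.index into preallocated
#     buffers to avoid the extra copy from int32 -> int64 index promotion.
#   * torch.vacc.empty_cache() is called opportunistically (best effort)
#     around prefill and when dropping the inputs_embeds buffer.
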
class VACCModelRunner(GPUModelRunner):

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        with _torch_cuda_wrapper():
            super().__init__(vllm_config, device)
        # assert self.speculative_config is None, "spec decode is not supported."
        self.use_cuda_graph = False
        self.cascade_attn_enabled = False
        # Skip _postprocess_tenosrs to avoid a GPU -> CPU copy.
        # self._postprocess_tenosrs()
        self.positions = self._make_buffer(self.max_num_tokens,
                                           dtype=torch.int32)
        self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                     dtype=torch.int32)

        # self.slot_mapping = self.slot_mapping.to(torch.int32)
        self.seq_lens = [0] * self.max_num_reqs

        if self.speculative_config is not None:
            self.speculative_config.disable_padded_drafter_batch = True

        self.deepstack_input_embeds = None
        if not (self.supports_mm_inputs and get_pp_group().is_first_rank
                and not self.model_config.is_encoder_decoder):
            if not self.enable_prompt_embeds:  # --enable-prompt-embeds
                try:
                    self.inputs_embeds = None
                    torch.vacc.empty_cache()
                except Exception as e:
                    logger.warning(
                        "Failed to release the self.inputs_embeds cache: %s", e)
        if self.uses_mrope:
            self.mrope_positions = self._make_buffer(
                (3, self.max_num_tokens + 1), dtype=torch.int32)

def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||
"""Update the cached states and the persistent batch with the scheduler
|
||
output.
|
||
|
||
The updated states are used by the `_prepare_inputs` function to create
|
||
the input GPU tensors for the model.
|
||
|
||
The SamplingMetadata is updated and copied to the GPU if there is a
|
||
new/resumed/paused/finished request in the batch.
|
||
"""
|
||
# Remove finished requests from the cached states.
|
||
for req_id in scheduler_output.finished_req_ids:
|
||
self.requests.pop(req_id, None)
|
||
# Remove the finished requests from the persistent batch.
|
||
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
||
# scheduled_req_ids overlap. This happens when a request is aborted and
|
||
# then resubmitted with the same ID. In this case, we treat them as two
|
||
# distinct requests - clearing the cached states for the first request
|
||
# and handling the second as a new request.
|
||
for req_id in scheduler_output.finished_req_ids:
|
||
self.input_batch.remove_request(req_id)
|
||
|
||
# Free the cached encoder outputs.
|
||
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
||
self.encoder_cache.pop(mm_hash, None)
|
||
|
||
# Remove the unscheduled requests from the persistent batch.
|
||
# NOTE(woosuk): The unscheduled requests are either preempted requests
|
||
# or running requests that are not scheduled in this step. We remove
|
||
# them from the persistent batch but keep their cached states since
|
||
# they will be scheduled again sometime in the future.
|
||
scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
|
||
cached_req_ids = self.input_batch.req_id_to_index.keys()
|
||
unscheduled_req_ids = cached_req_ids - scheduled_req_ids
|
||
# NOTE(woosuk): The persistent batch optimization assumes that
|
||
# consecutive batches contain mostly the same requests. If batches
|
||
# have low request overlap (e.g., alternating between two distinct
|
||
# sets of requests), this optimization becomes very inefficient.
|
||
for req_id in unscheduled_req_ids:
|
||
self.input_batch.remove_request(req_id)
|
||
|
||
reqs_to_add: list[CachedRequestState] = []
|
||
# Add new requests to the cached states.
|
||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||
req_id = new_req_data.req_id
|
||
sampling_params = new_req_data.sampling_params
|
||
pooling_params = new_req_data.pooling_params
|
||
|
||
if sampling_params and \
|
||
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||
generator = torch.Generator(device=self.device)
|
||
generator.manual_seed(sampling_params.seed)
|
||
else:
|
||
generator = None
|
||
|
||
if self.is_pooling_model:
|
||
assert pooling_params is not None
|
||
task = pooling_params.task
|
||
assert task is not None, "You did not set `task` in the API"
|
||
|
||
model = cast(VllmModelForPooling, self.get_model())
|
||
to_update = model.pooler.get_pooling_updates(task)
|
||
to_update.apply(pooling_params)
|
||
|
||
req_state = CachedRequestState(
|
||
req_id=req_id,
|
||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||
prompt_embeds=new_req_data.prompt_embeds,
|
||
deepstack_input_embeds=new_req_data.deepstack_input_embeds,
|
||
mm_features=new_req_data.mm_features,
|
||
sampling_params=sampling_params,
|
||
pooling_params=pooling_params,
|
||
generator=generator,
|
||
block_ids=new_req_data.block_ids,
|
||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||
output_token_ids=[],
|
||
lora_request=new_req_data.lora_request,
|
||
)
|
||
self.requests[req_id] = req_state
|
||
|
||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||
if self.uses_mrope:
|
||
self._init_mrope_positions(req_state)
|
||
|
||
reqs_to_add.append(req_state)
|
||
|
||
# Update the states of the running/resumed requests.
|
||
is_last_rank = get_pp_group().is_last_rank
|
||
req_data = scheduler_output.scheduled_cached_reqs
|
||
for i, req_id in enumerate(req_data.req_ids):
|
||
req_state = self.requests[req_id]
|
||
num_computed_tokens = req_data.num_computed_tokens[i]
|
||
new_block_ids = req_data.new_block_ids[i]
|
||
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
||
|
||
# Update the cached states.
|
||
req_state.num_computed_tokens = num_computed_tokens
|
||
|
||
if not is_last_rank:
|
||
# When using PP, the scheduler sends the sampled tokens back,
|
||
# because there's no direct communication between the first-
|
||
# stage worker and the last-stage worker.
|
||
new_token_ids = req_data.new_token_ids[i]
|
||
# Add the sampled token(s) from the previous step (if any).
|
||
# This doesn't include "unverified" tokens like spec tokens.
|
||
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
|
||
req_state.num_tokens)
|
||
if num_new_tokens == 1:
|
||
# Avoid slicing list in most common case.
|
||
req_state.output_token_ids.append(new_token_ids[-1])
|
||
elif num_new_tokens > 0:
|
||
req_state.output_token_ids.extend(
|
||
new_token_ids[-num_new_tokens:])
|
||
|
||
# Update the block IDs.
|
||
if not resumed_from_preemption:
|
||
if new_block_ids is not None:
|
||
# Append the new blocks to the existing block IDs.
|
||
for block_ids, new_ids in zip(req_state.block_ids,
|
||
new_block_ids):
|
||
block_ids.extend(new_ids)
|
||
else:
|
||
assert new_block_ids is not None
|
||
# The request is resumed from preemption.
|
||
# Replace the existing block IDs with the new ones.
|
||
req_state.block_ids = new_block_ids
|
||
|
||
req_index = self.input_batch.req_id_to_index.get(req_id)
|
||
if req_index is None:
|
||
# The request is not in the persistent batch.
|
||
# The request was either preempted and resumed later, or was not
|
||
# scheduled in the previous step and needs to be added again.
|
||
reqs_to_add.append(req_state)
|
||
continue
|
||
|
||
# Update the persistent batch.
|
||
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||
num_computed_tokens)
|
||
if new_block_ids is not None:
|
||
self.input_batch.block_table.append_row(
|
||
new_block_ids, req_index)
|
||
|
||
# For the last rank, we don't need to update the token_ids_cpu
|
||
# because the sampled tokens are already cached.
|
||
if not is_last_rank:
|
||
# Add new_token_ids to token_ids_cpu.
|
||
start_token_index = num_computed_tokens
|
||
end_token_index = num_computed_tokens + len(new_token_ids)
|
||
self.input_batch.token_ids_cpu[
|
||
req_index,
|
||
start_token_index:end_token_index] = new_token_ids
|
||
self.input_batch.num_tokens_no_spec[
|
||
req_index] = end_token_index
|
||
self.input_batch.num_tokens[req_index] = end_token_index
|
||
|
||
# Add spec_token_ids to token_ids_cpu.
|
||
spec_token_ids = (
|
||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
|
||
if spec_token_ids:
|
||
num_spec_tokens = len(spec_token_ids)
|
||
start_index = self.input_batch.num_tokens_no_spec[req_index]
|
||
end_token_index = start_index + num_spec_tokens
|
||
self.input_batch.token_ids_cpu[
|
||
req_index, start_index:end_token_index] = spec_token_ids
|
||
# NOTE(woosuk): `num_tokens` here may include spec tokens.
|
||
self.input_batch.num_tokens[req_index] += num_spec_tokens
|
||
|
||
# Add the new or resumed requests to the persistent batch.
|
||
# The smaller empty indices are filled first.
|
||
for request in reqs_to_add:
|
||
self.input_batch.add_request(request)
|
||
|
||
# Condense the batched states if there are gaps left by removed requests
|
||
self.input_batch.condense()
|
||
# Allow attention backend to reorder the batch, potentially
|
||
self._may_reorder_batch(scheduler_output)
|
||
# Refresh batch metadata with any pending updates.
|
||
self.input_batch.refresh_metadata()
|
||
|
||
# @trace_time('_prepare_inputs')
|
||
def _prepare_inputs(
|
||
self, scheduler_output: "SchedulerOutput"
|
||
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
|
||
Optional[SpecDecodeMetadata], np.ndarray,
|
||
Optional[CommonAttentionMetadata], int, Optional[UBatchSlices],
|
||
Optional[torch.Tensor]]:
|
||
"""
|
||
:return: tuple[
|
||
attn_metadata: layer-to-attention_metadata mapping,
|
||
logits_indices, spec_decode_metadata
|
||
]
|
||
"""
|
||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||
assert total_num_scheduled_tokens > 0
|
||
num_reqs = self.input_batch.num_reqs
|
||
assert num_reqs > 0
|
||
|
||
# OPTIMIZATION: Start copying the block table first.
|
||
# This way, we can overlap the copy with the following CPU operations.
|
||
self.input_batch.block_table.commit_block_table(num_reqs) # copy to gpu max_model_len//block_size
|
||
|
||
# Get the number of scheduled tokens for each request.
|
||
req_ids = self.input_batch.req_ids
|
||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||
max_num_scheduled_tokens = max(tokens)
|
||
|
||
# Get request indices.
|
||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||
num_scheduled_tokens)
|
||
|
||
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
|
||
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||
num_scheduled_tokens)
|
||
|
||
# Get positions.
|
||
positions_np = self.positions.np[:total_num_scheduled_tokens]
|
||
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
||
arange,
|
||
out=positions_np)
|
||
|
||
# Calculate M-RoPE positions.
|
||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||
if self.uses_mrope:
|
||
self._calc_mrope_positions(scheduler_output)
|
||
|
||
# Get token indices.
|
||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
|
||
# where M is the max_model_len.
|
||
token_indices = (positions_np +
|
||
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
||
token_indices_tensor = torch.from_numpy(token_indices)
|
||
|
||
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
||
# because torch.index_select is much faster than np.take for large
|
||
# tensors.
|
||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||
0,
|
||
token_indices_tensor,
|
||
out=self.input_ids.cpu[:total_num_scheduled_tokens])
|
||
if self.enable_prompt_embeds:
|
||
is_token_ids = self.input_batch.is_token_ids.flatten()
|
||
torch.index_select(
|
||
is_token_ids,
|
||
0,
|
||
token_indices_tensor,
|
||
out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
|
||
|
||
# Because we did not pre-allocate a massive prompt_embeds CPU tensor on
|
||
# the InputBatch, we need to fill in the prompt embeds into the expected
|
||
# spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
|
||
if self.input_batch.req_prompt_embeds:
|
||
output_idx = 0
|
||
deepstack_input_embeds_all_seqs = []
|
||
for req_idx in range(num_reqs):
|
||
num_sched = num_scheduled_tokens[req_idx]
|
||
|
||
# Skip if this request doesn't have embeddings
|
||
if req_idx not in self.input_batch.req_prompt_embeds:
|
||
output_idx += num_sched
|
||
continue
|
||
|
||
# Skip if no tokens scheduled
|
||
if num_sched <= 0:
|
||
output_idx += num_sched
|
||
continue
|
||
|
||
req_embeds = self.input_batch.req_prompt_embeds[req_idx]
|
||
deepstack_input_embeds = None
|
||
if self.input_batch.req_deepstack_input_embeds:
|
||
deepstack_input_embeds = self.input_batch.req_deepstack_input_embeds[req_idx]
|
||
start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]
|
||
|
||
# Skip if trying to read beyond available embeddings
|
||
if start_pos >= req_embeds.shape[0]:
|
||
output_idx += num_sched
|
||
continue
|
||
|
||
# Copy available embeddings
|
||
end_pos = start_pos + num_sched
|
||
actual_end = min(end_pos, req_embeds.shape[0])
|
||
actual_num_sched = actual_end - start_pos
|
||
|
||
if actual_num_sched > 0: # for prefill stage
|
||
self.inputs_embeds.cpu[output_idx:output_idx +
|
||
actual_num_sched].copy_(
|
||
req_embeds[start_pos:actual_end]
|
||
)
|
||
if deepstack_input_embeds is not None:
|
||
deepstack_input_embeds_all_seqs.append(deepstack_input_embeds)
|
||
|
||
output_idx += num_sched
|
||
if len(deepstack_input_embeds_all_seqs) > 0: # for multi-batch
|
||
self.deepstack_input_embeds = torch.concatenate(deepstack_input_embeds_all_seqs, -2)
|
||
|
||
self.input_batch.block_table.compute_slot_mapping(
|
||
req_indices, positions_np)
|
||
self.input_batch.block_table.commit_slot_mapping(
|
||
total_num_scheduled_tokens) # copy to gpu
|
||
|
||
# Prepare the attention metadata.
|
||
self.query_start_loc.np[0] = 0
|
||
self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens
|
||
# Note: pad query_start_loc to be non-decreasing, as kernels
|
||
# like FlashAttention requires that
|
||
self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1])
|
||
|
||
num_computed_tokens = scheduler_output.scheduled_cached_reqs.num_computed_tokens
|
||
is_all_decode = len(num_computed_tokens) > 0 and all([i > 0 for i in num_computed_tokens])
|
||
|
||
if not is_all_decode:
|
||
self.query_start_loc.copy_to_gpu()
|
||
query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
|
||
else:
|
||
query_start_loc = self.query_start_loc.np[:num_reqs + 1].tolist()
|
||
|
||
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
|
||
num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
|
||
num_tokens_unpadded)
|
||
uniform_decode = \
|
||
(max_num_scheduled_tokens == self.uniform_decode_query_len) and \
|
||
(total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
|
||
ubatch_slices, num_tokens_after_padding = \
|
||
ubatch_split(num_scheduled_tokens,
|
||
num_tokens_unpadded,
|
||
num_tokens_padded,
|
||
uniform_decode=uniform_decode,
|
||
vllm_config=self.vllm_config)
|
||
|
||
|
||
        # Patch (v0.11.0): keep seq_lens as a plain Python list instead of a CpuGpuBuffer.
|
||
# self.seq_lens.np[:num_reqs] = (
|
||
# self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||
# num_scheduled_tokens)
|
||
# # Fill unused with 0 for full cuda graph mode.
|
||
# self.seq_lens.np[num_reqs:].fill(0)
|
||
# self.seq_lens.copy_to_gpu()
|
||
# seq_lens = self.seq_lens.gpu[:num_reqs]
|
||
# max_seq_len = self.seq_lens.np[:num_reqs].max().item()
|
||
# self.seq_lens.np[:num_reqs] = (
|
||
# self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||
# num_scheduled_tokens)
|
||
|
||
        self.seq_lens[:num_reqs] = [
            int(computed_token) + int(scheduled_token)
            for computed_token, scheduled_token in zip(
                self.input_batch.num_computed_tokens_cpu[:num_reqs],
                num_scheduled_tokens)
        ]
|
||
seq_lens = self.seq_lens[:num_reqs]
|
||
max_seq_len = max(seq_lens) # seq_lens is list
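        # NOTE: keeping seq_lens as a plain Python list (see the patch above)
        # skips the per-step seq_lens copy_to_gpu() done by the upstream
        # runner; the attention metadata builders used on this backend are
        # assumed to accept a host-side list for seq_lens / seq_lens_cpu.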
|
||
|
||
# num_tokens = [
|
||
# self.requests[r].num_tokens for r in self.input_batch.req_ids
|
||
# ]
|
||
# num_tokens_np = np.array(num_tokens, dtype=np.int32)
|
||
|
||
# Record the index of requests that should not be sampled,
|
||
# so that we could clear the sampled tokens before returning
|
||
# discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
|
||
|
||
# discard_request_indices = np.nonzero(discard_requests_mask)[0]
|
||
# self.num_discarded_requests = len(discard_request_indices)
|
||
# self.discard_request_indices.np[:self.num_discarded_requests] = (
|
||
# discard_request_indices)
|
||
|
||
# self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
|
||
|
||
# Copy the tensors to the GPU.
|
||
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
|
||
|
||
if self.uses_mrope:
|
||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||
self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
|
||
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
|
||
non_blocking=True)
|
||
else:
|
||
# Common case (1D positions)
|
||
self.positions.copy_to_gpu(total_num_scheduled_tokens) # copy position
|
||
|
||
use_spec_decode = len(
|
||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||
|
||
if not use_spec_decode:
|
||
# NOTE(woosuk): Due to chunked prefills, the batch may contain
|
||
# partial requests. While we should not sample any token
|
||
# from these partial requests, we do so for simplicity.
|
||
# We will ignore the sampled tokens from the partial requests.
|
||
# TODO: Support prompt logprobs.
|
||
# logits_indices = query_start_loc[1:] - 1
|
||
if not is_all_decode:
|
||
logits_indices = query_start_loc[1:] - 1
|
||
else:
|
||
logits_indices = None
|
||
num_draft_tokens = None
|
||
spec_decode_metadata = None
|
||
else:
|
||
# Get the number of draft tokens for each request.
|
||
# Iterate over the dictionary rather than all requests since not all
|
||
# requests have draft tokens.
|
||
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
||
# For chunked prefills, use -1 as mask rather than 0, as guided
|
||
# decoding may rollback speculative tokens.
|
||
# num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
|
||
for req_id, draft_token_ids in (
|
||
scheduler_output.scheduled_spec_decode_tokens.items()):
|
||
req_idx = self.input_batch.req_id_to_index[req_id]
|
||
num_draft_tokens[req_idx] = len(draft_token_ids)
|
||
# num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
|
||
# self.input_batch.num_computed_tokens_cpu[req_idx]
|
||
# >= self.input_batch.num_prompt_tokens[req_idx]) else -1)
|
||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||
num_draft_tokens, cu_num_tokens) # copy
|
||
logits_indices = spec_decode_metadata.logits_indices
|
||
|
||
# For DECODE only cuda graph of some attention backends (e.g., GDN).
|
||
# self.num_decode_draft_tokens.np[:
|
||
# num_reqs] = num_decode_draft_tokens
|
||
# self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
|
||
# self.num_decode_draft_tokens.copy_to_gpu() #
|
||
|
||
|
||
# logits_indices_padded = None
|
||
# if self.cache_config.kv_sharing_fast_prefill:
|
||
# logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
|
||
# logits_indices)
|
||
|
||
attn_metadata: PerLayerAttnMetadata = {}
|
||
if ubatch_slices is not None:
|
||
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
||
|
||
# Used in the below loop.
|
||
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
|
||
        seq_lens_cpu = self.seq_lens  # patched: seq_lens is already a host-side list
|
||
num_computed_tokens_cpu = (
|
||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||
spec_decode_common_attn_metadata = None
|
||
if use_spec_decode:
|
||
self.num_accepted_tokens.np[:num_reqs] = (
|
||
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
|
||
# self.num_accepted_tokens.np[num_reqs:].fill(1)
|
||
# self.num_accepted_tokens.copy_to_gpu()
|
||
|
||
# Prepare the attention metadata for each KV cache group and make layers
|
||
# in the same group share the same metadata.
|
||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||
self.kv_cache_config.kv_cache_groups):
|
||
encoder_seq_lens = self._get_encoder_seq_lens(
|
||
scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs)
|
||
|
||
if isinstance(kv_cache_group_spec.kv_cache_spec,
|
||
EncoderOnlyAttentionSpec):
|
||
# Encoder-only layers do not have KV cache, so we need to
|
||
# create a dummy block table and slot mapping for them.
|
||
blk_table_tensor = torch.zeros(
|
||
(num_reqs, 1),
|
||
dtype=torch.int32,
|
||
device=self.device,
|
||
)
|
||
slot_mapping = torch.zeros(
|
||
(total_num_scheduled_tokens, ),
|
||
dtype=torch.int32, #int64 to int32
|
||
device=self.device,
|
||
)
|
||
num_common_prefix_blocks = 0
|
||
else:
|
||
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
||
# blk_table_tensor = blk_table.get_device_tensor(num_reqs)
|
||
blk_table_tensor = blk_table.block_table.gpu # slice first num_reqs move to metadata_build
|
||
slot_mapping = blk_table.slot_mapping.gpu[:
|
||
total_num_scheduled_tokens] # may copy
|
||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||
# graph mode.
|
||
# blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(
|
||
# -1)
|
||
num_common_prefix_blocks = (
|
||
scheduler_output.
|
||
num_common_prefix_blocks[kv_cache_group_id])
|
||
common_attn_metadata = CommonAttentionMetadata(
|
||
query_start_loc=query_start_loc,
|
||
query_start_loc_cpu=query_start_loc_cpu,
|
||
seq_lens=seq_lens,
|
||
seq_lens_cpu=seq_lens_cpu,
|
||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||
num_reqs=num_reqs,
|
||
num_actual_tokens=total_num_scheduled_tokens,
|
||
max_query_len=max_num_scheduled_tokens,
|
||
max_seq_len=max_seq_len,
|
||
block_table_tensor=blk_table_tensor,
|
||
slot_mapping=slot_mapping,
|
||
logits_indices_padded=None, #logits_indices_padded,
|
||
num_logits_indices=None, #logits_indices.size(0),
|
||
causal=True,
|
||
encoder_seq_lens=encoder_seq_lens,
|
||
)
|
||
|
||
if (self.speculative_config
|
||
and spec_decode_common_attn_metadata is None):
|
||
if isinstance(self.drafter, EagleProposer):
|
||
if (self.drafter.attn_layer_names[0]
|
||
in kv_cache_group_spec.layer_names):
|
||
spec_decode_common_attn_metadata = common_attn_metadata
|
||
else:
|
||
spec_decode_common_attn_metadata = common_attn_metadata
|
||
|
||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||
# Prepare for cascade attention if enabled & beneficial.
|
||
common_prefix_len = 0
|
||
builder = attn_group.get_metadata_builder()
|
||
if self.cascade_attn_enabled:
|
||
common_prefix_len = self._compute_cascade_attn_prefix_len(
|
||
num_scheduled_tokens,
|
||
num_common_prefix_blocks,
|
||
attn_group.kv_cache_spec,
|
||
builder,
|
||
)
|
||
|
||
extra_attn_metadata_args = {}
|
||
# if use_spec_decode and isinstance(builder,
|
||
# GDNAttentionMetadataBuilder):
|
||
# extra_attn_metadata_args = dict(
|
||
# num_accepted_tokens=self.num_accepted_tokens.
|
||
# gpu[:num_reqs],
|
||
# num_decode_draft_tokens_cpu=self.
|
||
# num_decode_draft_tokens.cpu[:num_reqs],
|
||
# )
|
||
|
||
if ubatch_slices is not None:
|
||
common_attn_metadata_list = split_attn_metadata(
|
||
ubatch_slices, common_attn_metadata)
|
||
for ubid, common_attn_metadata in enumerate(
|
||
common_attn_metadata_list):
|
||
attn_metadata_i = (attn_group.get_metadata_builder(
|
||
ubatch_id=ubid).build(
|
||
common_prefix_len=common_prefix_len,
|
||
common_attn_metadata=common_attn_metadata))
|
||
for layer_name in kv_cache_group_spec.layer_names:
|
||
assert type(attn_metadata) is list
|
||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||
else:
|
||
assert isinstance(attn_metadata, dict)
|
||
attn_metadata_i = builder.build(
|
||
common_prefix_len=common_prefix_len,
|
||
common_attn_metadata=common_attn_metadata,
|
||
**extra_attn_metadata_args)
|
||
for layer_name in attn_group.layer_names:
|
||
attn_metadata[layer_name] = attn_metadata_i
|
||
|
||
# Hot-Swap lora model
|
||
if self.lora_config:
|
||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||
|
||
return (attn_metadata, logits_indices, spec_decode_metadata,
|
||
num_scheduled_tokens, spec_decode_common_attn_metadata,
|
||
max_num_scheduled_tokens, ubatch_slices,
|
||
num_tokens_after_padding)
|
||
|
||
def _preprocess(
|
||
self,
|
||
scheduler_output: "SchedulerOutput",
|
||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||
ubatch_slices: Optional[UBatchSlices] = None,
|
||
num_tokens_after_padding: Optional[torch.Tensor] = None,
|
||
) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor],
|
||
Optional[torch.Tensor], torch.Tensor,
|
||
Optional[IntermediateTensors], dict[str, Any]]:
|
||
|
||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||
if ubatch_slices:
|
||
assert num_tokens_after_padding is not None
|
||
num_input_tokens = int(num_tokens_after_padding[0].item() * 2)
|
||
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
|
||
elif ubatch_slices is None:
|
||
num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
|
||
num_pad, num_tokens_after_padding = self.get_dp_padding(
|
||
num_input_tokens)
|
||
num_input_tokens += num_pad
|
||
|
||
deepstack_input_embeds = None
|
||
# _prepare_inputs may reorder the batch, so we must gather multi
|
||
# modal outputs after that to ensure the correct order
|
||
if (self.supports_mm_inputs and get_pp_group().is_first_rank
|
||
and not self.model_config.is_encoder_decoder):
|
||
# Run the multimodal encoder if any.
|
||
self._execute_mm_encoder(scheduler_output)
|
||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||
|
||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||
# embeddings), we always use embeddings (rather than token ids)
|
||
# as input to the multimodal model, even when the input is text.
|
||
inputs_embeds_scheduled = self.model.get_input_embeddings(
|
||
input_ids=self.input_ids.gpu[:num_scheduled_tokens],
|
||
multimodal_embeddings=mm_embeds or None,
|
||
)
|
||
|
||
# # TODO(woosuk): Avoid the copy. Optimize.
|
||
# self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(
|
||
# inputs_embeds_scheduled)
|
||
|
||
input_ids = None
|
||
# inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
||
|
||
            # For the full VL model, get_input_embeddings already returns a
            # tensor on the device, so it is passed through directly; there is
            # no need to stage it in self.inputs_embeds.gpu.
            # Only enable_prompt_embeds inputs (CPU input_embeddings) need
            # that copy.
|
||
inputs_embeds = inputs_embeds_scheduled
|
||
|
||
model_kwargs = {
|
||
**self._init_model_kwargs(num_scheduled_tokens),
|
||
**self._extract_mm_kwargs(scheduler_output),
|
||
}
|
||
elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
|
||
# Get the input embeddings for the tokens that are not input embeds,
|
||
# then put them into the appropriate positions.
|
||
# TODO(qthequartermasterman): Since even when prompt embeds are
|
||
# enabled, (a) not all requests will use prompt embeds, and (b)
|
||
# after the initial prompt is processed, the rest of the generated
|
||
# tokens will be token ids, it is not desirable to have the
|
||
# embedding layer outside of the CUDA graph all the time. The v0
|
||
# engine avoids this by "double compiling" the CUDA graph, once
|
||
# with input_ids and again with inputs_embeds, for all num_tokens.
|
||
# If a batch only has token ids, then including the embedding layer
|
||
# in the CUDA graph will be more performant (like in the else case
|
||
# below).
|
||
|
||
# patch here to opt embedding model
|
||
# token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens] \
|
||
# .nonzero(as_tuple=False) \
|
||
# .squeeze(1)
|
||
# Some tokens ids may need to become embeds
|
||
# if token_ids_idx.numel() > 0:
|
||
# token_ids = self.input_ids.gpu[token_ids_idx]
|
||
# tokens_to_embeds = self.model.get_input_embeddings(
|
||
# input_ids=token_ids)
|
||
# self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
|
||
|
||
# inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
||
|
||
token_ids_idx = self.is_token_ids.cpu[:num_scheduled_tokens] \
|
||
.nonzero(as_tuple=False) \
|
||
.squeeze(1)
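            # NOTE: running nonzero() on the CPU mirror of is_token_ids
            # (instead of the GPU buffer used in the commented-out code above)
            # avoids a device sync here; the result is only used to decide
            # which embedding path to take below.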
|
||
if token_ids_idx.numel() > 0:
|
||
inputs_embeds = self.model.get_input_embeddings(
|
||
input_ids=self.input_ids.gpu[:num_input_tokens])
|
||
else:
|
||
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
||
|
||
if self.deepstack_input_embeds is not None:
|
||
                # Model runner preprocessing: copy deepstack_input_embeds to the VACC device.
|
||
self.deepstack_input_embeds = self.deepstack_input_embeds.to(self.device)
|
||
deepstack_input_embeds = self.deepstack_input_embeds
|
||
#TODO for dict
|
||
# for key in self.deepstack_input_embeds:
|
||
# self.deepstack_input_embeds[key] = self.deepstack_input_embeds[key].to(self.device)
|
||
|
||
|
||
model_kwargs = self._init_model_kwargs(num_input_tokens)
|
||
input_ids = None
|
||
else:
|
||
# For text-only models, we use token ids as input.
|
||
# While it is possible to use embeddings as input just like the
|
||
# multimodal models, it is not desirable for performance since
|
||
# then the embedding layer is not included in the CUDA graph.
|
||
input_ids = self.input_ids.gpu[:num_input_tokens]
|
||
inputs_embeds = None
|
||
model_kwargs = self._init_model_kwargs(num_input_tokens)
|
||
if self.uses_mrope:
|
||
positions = self.mrope_positions.gpu[:, :num_input_tokens]
|
||
else:
|
||
positions = self.positions.gpu[:num_input_tokens]
|
||
|
||
if get_pp_group().is_first_rank:
|
||
intermediate_tensors = None
|
||
else:
|
||
self.intermediate_tensors = intermediate_tensors
|
||
# intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||
# num_input_tokens, intermediate_tensors, True)
|
||
|
||
if (self.model_config.is_encoder_decoder
|
||
and scheduler_output.scheduled_encoder_inputs):
|
||
encoder_inputs = self._extract_encoder_inputs(scheduler_output)
|
||
model_kwargs.update(encoder_inputs)
|
||
if deepstack_input_embeds is not None:
|
||
model_kwargs['deepstack_input_embeds'] = deepstack_input_embeds
|
||
|
||
return (
|
||
num_scheduled_tokens,
|
||
num_input_tokens,
|
||
num_tokens_after_padding,
|
||
input_ids,
|
||
inputs_embeds,
|
||
positions,
|
||
intermediate_tensors,
|
||
model_kwargs,
|
||
)
|
||
|
||
|
||
def load_model(self, eep_scale_up: bool = False) -> None:
|
||
logger.info("Starting to load model %s...", self.model_config.model)
|
||
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
||
time_before_load = time.perf_counter()
|
||
model_loader = get_model_loader(self.load_config)
|
||
logger.info("Loading model from scratch...")
|
||
self.model = model_loader.load_model(
|
||
vllm_config=self.vllm_config, model_config=self.model_config)
|
||
if self.lora_config:
|
||
self.model = self.load_lora_model(self.model,
|
||
self.model_config,
|
||
self.scheduler_config,
|
||
self.lora_config,
|
||
self.device)
|
||
if hasattr(self, "drafter"):
|
||
logger.info("Loading drafter model...")
|
||
# self.drafter.load_model(self.model)
|
||
draft_model_eagle_load_model(self.drafter, self.model)
|
||
if self.use_aux_hidden_state_outputs:
|
||
self.model.set_aux_hidden_state_layers(
|
||
self.model.get_eagle3_aux_hidden_state_layers())
|
||
time_after_load = time.perf_counter()
|
||
self.model_memory_usage = m.consumed_memory
|
||
logger.info("Model loading took %.4f GiB and %.6f seconds",
|
||
self.model_memory_usage / GiB_bytes,
|
||
time_after_load - time_before_load)
|
||
register_module_trace(self.model)
|
||
register_module_trace(self.sampler)
|
||
|
||
def _sample(
|
||
self, logits: Optional[torch.Tensor],
|
||
spec_decode_metadata: Optional[SpecDecodeMetadata],
|
||
is_first_calculate
|
||
) -> SamplerOutput:
|
||
# Sample the next token and get logprobs if needed.
|
||
sampling_metadata = self.input_batch.sampling_metadata
|
||
if spec_decode_metadata is None:
|
||
sampler_output = self.sampler(
|
||
logits=logits,
|
||
sampling_metadata=sampling_metadata,
|
||
is_first_calculate=is_first_calculate,
|
||
req_ids = self.input_batch.req_ids
|
||
)
|
||
else:
|
||
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
||
# creates a new tensor with separate storage from the original
|
||
# logits tensor. This means any in-place operations on bonus_logits
|
||
# won't affect the original logits tensor.
|
||
|
||
            # bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            # With int32 indices, the plain advanced-indexing path first casts
            # them to int64 (an extra copy) before the gather. Instead,
            # preallocate the output with new_empty and gather directly via
            # torch.ops.aten.index(..., out=...), skipping that copy.

            # logits = logits.to(torch.float32)
            bonus_logits_indices = spec_decode_metadata.bonus_logits_indices
            if isinstance(bonus_logits_indices, np.ndarray):
                if bonus_logits_indices.shape[0] == 1:
                    # Single index: a narrow slice is only a copy.
                    bonus_logits = logits[
                        bonus_logits_indices[0]:bonus_logits_indices[0] + 1]
                else:
                    # bonus_logits = logits[bonus_logits_indices]  # copy + slice

                    bonus_logits_indices = torch.from_numpy(
                        bonus_logits_indices)  # new + slice
                    size = len(spec_decode_metadata.bonus_logits_indices.shape)
                    bonus_logits = logits.new_empty(
                        spec_decode_metadata.bonus_logits_indices.shape +
                        logits.shape[size:])
                    torch.ops.aten.index(logits, [bonus_logits_indices],
                                         out=bonus_logits)
            else:
                size = len(spec_decode_metadata.bonus_logits_indices.shape)
                bonus_logits = logits.new_empty(
                    spec_decode_metadata.bonus_logits_indices.shape +
                    logits.shape[size:])
                torch.ops.aten.index(
                    logits, [spec_decode_metadata.bonus_logits_indices],
                    out=bonus_logits)
|
||
|
||
sampler_output = self.sampler(
|
||
logits=bonus_logits,
|
||
sampling_metadata=sampling_metadata,
|
||
is_first_calculate=is_first_calculate,
|
||
req_ids = self.input_batch.req_ids
|
||
)
|
||
bonus_token_ids = sampler_output.sampled_token_ids
|
||
|
||
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
||
# separate storage from the original `logits` tensor. Therefore,
|
||
# it is safe to update `target_logits` in place.
|
||
# target_logits = logits[spec_decode_metadata.target_logits_indices]
|
||
target_logits_indices = spec_decode_metadata.target_logits_indices
|
||
if isinstance(target_logits_indices, np.ndarray):
|
||
if target_logits_indices.shape[0] == 1:
|
||
target_logits = logits[target_logits_indices[0] : target_logits_indices[0] + 1] # only copy
|
||
else:
|
||
target_logits_indices = torch.from_numpy(target_logits_indices)
|
||
size = len(spec_decode_metadata.target_logits_indices.shape)
|
||
target_logits = logits.new_empty(spec_decode_metadata.target_logits_indices.shape + logits.shape[size:] )
|
||
torch.ops.aten.index(logits, [target_logits_indices], out=target_logits)
|
||
# target_logits = logits[target_logits_indices]
|
||
else:
|
||
size = len(spec_decode_metadata.target_logits_indices.shape)
|
||
target_logits = logits.new_empty(spec_decode_metadata.target_logits_indices.shape + logits.shape[size:] )
|
||
torch.ops.aten.index(logits, [spec_decode_metadata.target_logits_indices], out=target_logits)
|
||
|
||
output_token_ids = self.rejection_sampler(
|
||
spec_decode_metadata,
|
||
None, # draft_probs
|
||
target_logits,
|
||
bonus_token_ids,
|
||
sampling_metadata,
|
||
)
|
||
sampler_output.sampled_token_ids = output_token_ids
|
||
self._update_states_after_model_execute(output_token_ids)
|
||
|
||
return sampler_output
|
||
|
||
def _pool(
|
||
self,
|
||
hidden_states: torch.Tensor,
|
||
num_scheduled_tokens: int,
|
||
num_scheduled_tokens_np: np.ndarray,
|
||
) -> ModelRunnerOutput:
|
||
assert self.input_batch.num_reqs ==\
|
||
len(self.input_batch.pooling_params), \
|
||
"Either all or none of the requests in" \
|
||
" a batch must be pooling request"
|
||
|
||
hidden_states = hidden_states[:num_scheduled_tokens]
|
||
pooling_metadata = self.input_batch.get_pooling_metadata()
|
||
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
|
||
device=hidden_states.device)
|
||
seq_lens_cpu = self.seq_lens[:self.input_batch.num_reqs]
|
||
|
||
model = cast(VllmModelForPooling, self.model)
|
||
raw_pooler_output: PoolerOutput = model.pooler(
|
||
hidden_states=hidden_states,
|
||
pooling_metadata=pooling_metadata,
|
||
)
|
||
raw_pooler_output = json_map_leaves(
|
||
lambda x: x.to("cpu", non_blocking=True),
|
||
raw_pooler_output,
|
||
)
|
||
self._sync_device()
|
||
|
||
pooler_output: list[Optional[torch.Tensor]] = []
|
||
for raw_output, seq_len, prompt_len in zip(
|
||
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
||
|
||
output = raw_output if seq_len == prompt_len else None
|
||
pooler_output.append(output)
|
||
|
||
return ModelRunnerOutput(
|
||
req_ids=self.input_batch.req_ids,
|
||
req_id_to_index=self.input_batch.req_id_to_index,
|
||
sampled_token_ids=[],
|
||
logprobs=None,
|
||
prompt_logprobs_dict={},
|
||
pooler_output=pooler_output,
|
||
)
|
||
|
||
@torch.inference_mode()
|
||
def execute_model(
|
||
self,
|
||
scheduler_output: "SchedulerOutput",
|
||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
|
||
with record_function_or_nullcontext("Preprocess"):
|
||
with self.synchronize_input_prep():
|
||
# Update persistent batch states.
|
||
self._update_states(scheduler_output)
|
||
|
||
if not scheduler_output.total_num_scheduled_tokens:
|
||
if not has_kv_transfer_group():
|
||
# Return empty ModelRunnerOutput if no work to do.
|
||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||
return self.kv_connector_no_forward(
|
||
scheduler_output, self.vllm_config)
|
||
if self.cache_config.kv_sharing_fast_prefill:
|
||
assert not self.input_batch.num_prompt_logprobs, (
|
||
"--kv-sharing-fast-prefill produces incorrect "
|
||
"logprobs for prompt tokens, tokens, please disable "
|
||
"it when the requests need prompt logprobs")
|
||
|
||
# Prepare the decoder inputs.
|
||
(attn_metadata, logits_indices, spec_decode_metadata,
|
||
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
|
||
max_query_len, ubatch_slices, num_tokens_after_padding
|
||
) = self._prepare_inputs(scheduler_output)
|
||
|
||
            if next(iter(attn_metadata.values())).prefill_metadata is not None:
|
||
try:
|
||
torch.vacc.empty_cache()
|
||
except Exception as e:
|
||
logger.warn("vacc empty cache skiping...")
|
||
|
||
(
|
||
num_scheduled_tokens,
|
||
num_input_tokens,
|
||
num_tokens_across_dp,
|
||
input_ids,
|
||
inputs_embeds,
|
||
positions,
|
||
intermediate_tensors,
|
||
model_kwargs,
|
||
) = self._preprocess(scheduler_output, intermediate_tensors,
|
||
ubatch_slices, num_tokens_after_padding)
|
||
|
||
uniform_decode = (max_query_len
|
||
== self.uniform_decode_query_len) and (
|
||
num_scheduled_tokens
|
||
== self.input_batch.num_reqs * max_query_len)
|
||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||
uniform_decode=uniform_decode)
|
||
cudagraph_runtime_mode, batch_descriptor = \
|
||
self.cudagraph_dispatcher.dispatch(batch_descriptor)
|
||
|
||
# This is currently to get around the assert in the DPMetadata
|
||
# where it wants `num_tokens_across_dp` to align with `num_tokens`
|
||
if ubatch_slices is not None:
|
||
num_input_tokens = ubatch_slices[0].num_tokens
|
||
|
||
# Run the model.
|
||
# Use persistent buffers for CUDA graphs.
|
||
with (set_forward_context(
|
||
attn_metadata,
|
||
self.vllm_config,
|
||
num_tokens=num_input_tokens,
|
||
num_tokens_across_dp=num_tokens_across_dp,
|
||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||
batch_descriptor=batch_descriptor,
|
||
ubatch_slices=ubatch_slices,
|
||
), record_function_or_nullcontext("Forward"),
|
||
self.maybe_get_kv_connector_output(scheduler_output) as
|
||
kv_connector_output):
|
||
model_output = self.model(
|
||
input_ids=input_ids,
|
||
positions=positions,
|
||
intermediate_tensors=intermediate_tensors,
|
||
inputs_embeds=inputs_embeds,
|
||
**model_kwargs,
|
||
)
|
||
|
||
with record_function_or_nullcontext("Postprocess"):
|
||
if self.use_aux_hidden_state_outputs:
|
||
# True when EAGLE 3 is used.
|
||
hidden_states, aux_hidden_states = model_output
|
||
else:
|
||
# Common case.
|
||
hidden_states = model_output
|
||
aux_hidden_states = None
|
||
|
||
if not self.broadcast_pp_output:
|
||
# Common case.
|
||
if not get_pp_group().is_last_rank:
|
||
# Return the intermediate tensors.
|
||
assert isinstance(hidden_states, IntermediateTensors)
|
||
hidden_states.kv_connector_output = kv_connector_output
|
||
return hidden_states
|
||
|
||
if self.is_pooling_model:
|
||
# Return the pooling output.
|
||
output = self._pool(hidden_states, num_scheduled_tokens,
|
||
num_scheduled_tokens_np)
|
||
output.kv_connector_output = kv_connector_output
|
||
return output
|
||
|
||
if logits_indices is None or logits_indices.shape[0] == hidden_states.shape[0]:
|
||
sample_hidden_states = hidden_states
|
||
else:
|
||
if isinstance(logits_indices, np.ndarray):
|
||
if logits_indices.shape[0] == 1:
|
||
                            sample_hidden_states = hidden_states[
                                logits_indices[0]:logits_indices[0] + 1]  # only copy
|
||
else:
|
||
logits_indices = torch.from_numpy(logits_indices)
|
||
size = len(logits_indices.shape)
|
||
new_shape = logits_indices.shape + hidden_states.shape[size:]
|
||
sample_hidden_states = hidden_states.new_empty(new_shape) # copy + slice
|
||
torch.ops.aten.index(hidden_states, [logits_indices], out=sample_hidden_states)
|
||
else:
|
||
# sample_hidden_states = hidden_states[logits_indices]
|
||
size = len(logits_indices.shape)
|
||
new_shape = logits_indices.shape + hidden_states.shape[size:] # new + slice
|
||
sample_hidden_states = hidden_states.new_empty(new_shape)
|
||
torch.ops.aten.index(hidden_states, [logits_indices], out=sample_hidden_states)
|
||
|
||
logits = self.model.compute_logits(sample_hidden_states)
|
||
else:
|
||
# Rare case.
|
||
assert not self.is_pooling_model
|
||
|
||
if not get_pp_group().is_last_rank:
|
||
all_gather_tensors = {
|
||
"residual":
|
||
not is_residual_scattered_for_sp(
|
||
self.vllm_config, num_input_tokens)
|
||
}
|
||
get_pp_group().send_tensor_dict(
|
||
hidden_states.tensors,
|
||
all_gather_group=get_tp_group(),
|
||
all_gather_tensors=all_gather_tensors)
|
||
logits = None
|
||
else:
|
||
|
||
# sample_hidden_states = hidden_states[logits_indices]
|
||
if logits_indices is None or logits_indices.shape[0] == hidden_states.shape[0]:
|
||
sample_hidden_states = hidden_states
|
||
else:
|
||
if isinstance(logits_indices, np.ndarray):
|
||
if logits_indices.shape[0] == 1:
|
||
                                sample_hidden_states = hidden_states[
                                    logits_indices[0]:logits_indices[0] + 1]  # only copy
|
||
else:
|
||
logits_indices = torch.from_numpy(logits_indices) # new + slice
|
||
size = len(logits_indices.shape)
|
||
new_shape = logits_indices.shape + hidden_states.shape[size:]
|
||
sample_hidden_states = hidden_states.new_empty(new_shape)
|
||
torch.ops.aten.index(hidden_states, [logits_indices], out=sample_hidden_states)
|
||
else:
|
||
# sample_hidden_states = hidden_states[logits_indices] # copy + slice
|
||
size = len(logits_indices.shape)
|
||
new_shape = logits_indices.shape + hidden_states.shape[size:] # new + slice
|
||
sample_hidden_states = hidden_states.new_empty(new_shape)
|
||
torch.ops.aten.index(hidden_states, [logits_indices], out=sample_hidden_states)
|
||
|
||
logits = self.model.compute_logits(sample_hidden_states)
|
||
|
||
model_output_broadcast_data = {}
|
||
if logits is not None:
|
||
model_output_broadcast_data["logits"] = logits.contiguous()
|
||
|
||
model_output_broadcast_data = get_pp_group(
|
||
).broadcast_tensor_dict(model_output_broadcast_data,
|
||
src=len(get_pp_group().ranks) - 1)
|
||
assert model_output_broadcast_data is not None
|
||
logits = model_output_broadcast_data["logits"]
|
||
|
||
# Apply structured output bitmasks if present
|
||
if scheduler_output.grammar_bitmask is not None:
|
||
apply_grammar_bitmask(scheduler_output, self.input_batch,
|
||
logits, self.device)
|
||
|
||
with record_function_or_nullcontext("Sample"):
|
||
            is_first_calculate = next(iter(attn_metadata.values())).num_prefills
|
||
sampler_output = self._sample(logits, spec_decode_metadata, is_first_calculate)
|
||
|
||
def propose_draft_token_ids(sampled_token_ids):
|
||
assert spec_decode_common_attn_metadata is not None
|
||
with record_function_or_nullcontext("Draft"):
|
||
self._draft_token_ids = self.propose_draft_token_ids(
|
||
scheduler_output,
|
||
sampled_token_ids,
|
||
self.input_batch.sampling_metadata,
|
||
hidden_states,
|
||
sample_hidden_states,
|
||
aux_hidden_states,
|
||
spec_decode_metadata,
|
||
spec_decode_common_attn_metadata,
|
||
)
|
||
|
||
use_padded_batch_for_eagle = self.speculative_config and \
|
||
self.speculative_config.use_eagle() and \
|
||
not self.speculative_config.disable_padded_drafter_batch
|
||
effective_drafter_max_model_len = self.max_model_len
|
||
if effective_drafter_max_model_len is None:
|
||
effective_drafter_max_model_len = self.model_config.max_model_len
|
||
if (self.speculative_config
|
||
and self.speculative_config.draft_model_config is not None
|
||
and self.speculative_config.draft_model_config.max_model_len
|
||
is not None):
|
||
effective_drafter_max_model_len = (
|
||
self.speculative_config.draft_model_config.max_model_len)
|
||
input_fits_in_drafter = spec_decode_common_attn_metadata and (
|
||
max(spec_decode_common_attn_metadata.seq_lens) +
|
||
self.speculative_config.num_speculative_tokens
|
||
<= effective_drafter_max_model_len)
|
||
if use_padded_batch_for_eagle and input_fits_in_drafter:
|
||
# EAGLE speculative decoding can use the GPU sampled tokens
|
||
# as inputs, and does not need to wait for bookkeeping to finish.
|
||
propose_draft_token_ids(sampler_output.sampled_token_ids)
|
||
|
||
with record_function_or_nullcontext("Bookkeep"):
|
||
(
|
||
num_nans_in_logits,
|
||
logprobs_lists,
|
||
valid_sampled_token_ids,
|
||
prompt_logprobs_dict,
|
||
req_ids_output_copy,
|
||
req_id_to_index_output_copy,
|
||
invalid_req_indices,
|
||
) = self._bookkeeping_sync(scheduler_output, sampler_output,
|
||
logits, hidden_states,
|
||
num_scheduled_tokens)
|
||
|
||
if (self.speculative_config and not use_padded_batch_for_eagle
|
||
and input_fits_in_drafter):
|
||
# ngram and other speculative decoding methods use the sampled
|
||
# tokens on the CPU, so they are run after bookkeeping.
|
||
propose_draft_token_ids(valid_sampled_token_ids)
|
||
|
||
with record_function_or_nullcontext("EPLB"):
|
||
self.eplb_step()
|
||
|
||
output = ModelRunnerOutput(
|
||
req_ids=req_ids_output_copy,
|
||
req_id_to_index=req_id_to_index_output_copy,
|
||
sampled_token_ids=valid_sampled_token_ids,
|
||
logprobs=logprobs_lists,
|
||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||
pooler_output=[],
|
||
kv_connector_output=kv_connector_output,
|
||
num_nans_in_logits=num_nans_in_logits,
|
||
)
|
||
|
||
if not self.use_async_scheduling:
|
||
return output
|
||
|
||
return AsyncGPUModelRunnerOutput(
|
||
model_runner_output=output,
|
||
sampled_token_ids=sampler_output.sampled_token_ids,
|
||
invalid_req_indices=invalid_req_indices,
|
||
async_output_copy_stream=self.async_output_copy_stream,
|
||
)
|
||
|
||
def propose_draft_token_ids(
|
||
self,
|
||
scheduler_output: "SchedulerOutput",
|
||
sampled_token_ids: Union[torch.Tensor, list[list[int]]],
|
||
sampling_metadata: SamplingMetadata,
|
||
hidden_states: torch.Tensor,
|
||
sample_hidden_states: torch.Tensor,
|
||
aux_hidden_states: Optional[list[torch.Tensor]],
|
||
spec_decode_metadata: Optional[SpecDecodeMetadata],
|
||
common_attn_metadata: CommonAttentionMetadata,
|
||
) -> Union[list[list[int]], torch.Tensor]:
|
||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||
if self.speculative_config.method == "ngram":
|
||
assert isinstance(sampled_token_ids, list)
|
||
assert isinstance(self.drafter, NgramProposer)
|
||
draft_token_ids = self.drafter.propose(
|
||
sampled_token_ids, self.input_batch.req_ids,
|
||
self.input_batch.num_tokens_no_spec,
|
||
self.input_batch.token_ids_cpu,
|
||
self.input_batch.spec_decode_unsupported_reqs)
|
||
elif self.speculative_config.method == "medusa":
|
||
assert isinstance(sampled_token_ids, list)
|
||
assert isinstance(self.drafter, MedusaProposer)
|
||
|
||
if sample_hidden_states.shape[0] == len(sampled_token_ids):
|
||
# The input to the target model does not include draft tokens.
|
||
hidden_states = sample_hidden_states
|
||
else:
|
||
indices = []
|
||
offset = 0
|
||
assert spec_decode_metadata is not None
                for num_draft, tokens in zip(
                        spec_decode_metadata.num_draft_tokens,
                        sampled_token_ids):
                    indices.append(offset + len(tokens) - 1)
                    offset += num_draft + 1
                indices = torch.tensor(indices, device=self.device)
                hidden_states = sample_hidden_states[indices]

            draft_token_ids = self.drafter.propose(
                target_hidden_states=hidden_states,
                sampling_metadata=sampling_metadata,
            )
        elif self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)

            if self.speculative_config.disable_padded_drafter_batch:
                # When padded-batch is disabled, sampled_token_ids should be
                # the CPU-side list[list[int]] of valid sampled tokens for
                # each request, with invalid requests having empty lists.
                assert isinstance(sampled_token_ids, list), \
                    "sampled_token_ids should be a python list when " \
                    "padded-batch is disabled."
                next_token_ids = self.drafter.prepare_next_token_ids_cpu(
                    sampled_token_ids, self.requests, self.input_batch,
                    scheduler_output.num_scheduled_tokens)
            else:
                # When using padded-batch, sampled_token_ids should be the
                # GPU tensor of sampled tokens for each request, of shape
                # (num_reqs, num_spec_tokens + 1), with rejected tokens
                # having value -1.
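                # For example, a row [t0, t1, -1] means this request has two
                # valid sampled tokens this step; -1 marks the rejected or
                # unused trailing slots.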
                assert isinstance(sampled_token_ids, torch.Tensor), \
                    "sampled_token_ids should be a torch.Tensor when " \
                    "padded-batch is enabled."
                next_token_ids, valid_sampled_tokens_count = \
                    self.drafter.prepare_next_token_ids_padded(
                        common_attn_metadata,
                        sampled_token_ids,
                        self.requests,
                        self.input_batch,
                        self.discard_request_indices.gpu,
                        self.num_discarded_requests
                    )

            if spec_decode_metadata is None:
                token_indices_to_sample = None
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions.gpu[:num_scheduled_tokens]
                if self.use_aux_hidden_state_outputs:
                    assert aux_hidden_states is not None
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
                        dim=-1)
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
            else:
                if self.speculative_config.disable_padded_drafter_batch:
                    token_indices_to_sample = None
                    common_attn_metadata, token_indices = \
                        self.drafter.prepare_inputs(
                            common_attn_metadata,
                            sampled_token_ids,
                            spec_decode_metadata.num_draft_tokens)
                else:
                    common_attn_metadata, token_indices, \
                        token_indices_to_sample = \
                        self.drafter.prepare_inputs_padded(
                            common_attn_metadata,
                            spec_decode_metadata,
                            valid_sampled_tokens_count)

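                # Gather the drafter inputs at token_indices. Rather than
                # plain advanced indexing (e.g.
                # `self.input_ids.gpu[token_indices]`, as in the reference
                # GPU runner), the gathers below pre-allocate the output and
                # use torch.ops.aten.index(..., out=...); the single-index
                # case is served by a contiguous slice so no index tensor
                # needs to be built or copied to the device.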
                assert len(token_indices) > 0, len(token_indices)
                if len(token_indices) > 1:
                    token_indices_vacc = torch.tensor(
                        token_indices,
                        dtype=torch.int32).to(self.input_ids.gpu.device,
                                              non_blocking=True)
                    target_token_ids = self.input_ids.gpu.new_empty(
                        token_indices_vacc.shape + self.input_ids.gpu.shape[1:])
                    torch.ops.aten.index(self.input_ids.gpu,
                                         [token_indices_vacc],
                                         out=target_token_ids)
                else:  # len(token_indices) == 1
                    target_token_ids = self.input_ids.gpu[
                        token_indices[0]:token_indices[0] + 1]

                # TODO(woosuk): Support M-RoPE.
                if len(token_indices) == 1:
                    target_positions = self.positions.gpu[
                        token_indices[0]:token_indices[0] + 1]
                else:
                    target_positions = self.positions.gpu.new_empty(
                        token_indices_vacc.shape + self.positions.gpu.shape[1:])
                    torch.ops.aten.index(self.positions.gpu,
                                         [token_indices_vacc],
                                         out=target_positions)

                if self.use_aux_hidden_state_outputs:
                    assert aux_hidden_states is not None
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1)
                else:
                    if len(token_indices) == hidden_states.shape[0]:
                        # All sampled tokens were accepted: no gather needed.
                        target_hidden_states = hidden_states
                    elif len(token_indices) == 1:
                        # Only one element: a plain slice avoids an extra
                        # copy compared to copy + slice.
                        target_hidden_states = hidden_states[
                            token_indices[0]:token_indices[0] + 1]
                    else:
                        target_hidden_states = hidden_states.new_empty(
                            token_indices_vacc.shape + hidden_states.shape[1:])
                        torch.ops.aten.index(hidden_states,
                                             [token_indices_vacc],
                                             out=target_hidden_states)

            mm_embeds = None
            if self.supports_mm_inputs:
                mm_embeds = self._gather_mm_embeddings(scheduler_output,
                                                       shift_computed_tokens=1)

            draft_token_ids = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                next_token_ids=next_token_ids,
                last_token_indices=token_indices_to_sample,
                sampling_metadata=sampling_metadata,
                common_attn_metadata=common_attn_metadata,
                mm_embeds=mm_embeds,
            )
        return draft_token_ids

    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        # Inputs:
        # cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
        # num_draft_tokens:         [  3,   0,   2,   0,   1]
        # Outputs:
        # cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
        # logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
        #                            206, 207, 208]
        # target_logits_indices:    [  0,   1,   2,   5,   6,   9]
        # bonus_logits_indices:     [  3,   4,   7,   8,  10]
        # (See _spec_decode_indices_example at the end of this file for a
        # runnable check of this worked example.)

        # Compute the logits indices.
        # [4, 1, 3, 1, 2]
        num_sampled_tokens = num_draft_tokens + 1

        # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11]
        # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        cu_num_sampled_tokens, arange = self._get_cumsum_and_arange(
            num_sampled_tokens, cumsum_dtype=np.int32)
        # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        logits_indices = np.repeat(
            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
        # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        logits_indices += arange

        # Compute the bonus logits indices.
        bonus_logits_indices = cu_num_sampled_tokens - 1

        # Compute the draft logits indices.
        # cu_num_draft_tokens: [3, 3, 5, 5, 6]
        # arange: [0, 1, 2, 0, 1, 0]
        cu_num_draft_tokens, arange = self._get_cumsum_and_arange(
            num_draft_tokens, cumsum_dtype=np.int32)
        # [0, 0, 0, 5, 5, 9]
        target_logits_indices = np.repeat(
            cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
        # [0, 1, 2, 5, 6, 9]
        target_logits_indices += arange

        # Compute the draft token ids, i.e. the input tokens that follow
        # each target position. Equivalent to
        #   self.input_ids.gpu[logits_indices][target_logits_indices + 1]
        # draft_token_indices: [1, 2, 3, 105, 106, 208]
        draft_token_indices = \
            logits_indices[target_logits_indices + 1].tolist()
        if len(draft_token_indices) == 1:
            # Single index: a contiguous slice needs only one device copy.
            draft_token_ids = self.input_ids.gpu[
                draft_token_indices[0]:draft_token_indices[0] + 1]
        else:
            # Advanced indexing: index-tensor copy + gather.
            draft_token_ids = self.input_ids.gpu[draft_token_indices]

        # Unlike the reference GPU runner, the cumulative/index arrays are
        # not copied to the device here; they stay as CPU numpy arrays and
        # cu_num_draft_tokens is left as None.
        metadata = SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=None,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
        return metadata

    def warming_up_model(self) -> None:
        logger.info("Warming up model for the compilation...")
        # Only generate a graph for the generic shape.
        with _set_global_compilation_settings(self.vllm_config):
            self._dummy_run(max(16, self.max_num_reqs))
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        # No device properties need to be initialized for this runner.
        pass

    def _sync_device(self) -> None:
        # Device synchronization is a no-op for this runner.
        pass

    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA)
        may want to separate requests based on whether the attention
        computation will be compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.
        """
        # Attention-free models have zero kv_cache_groups. However, models
        # like Mamba are also attention-free but use the kv_cache to keep
        # their internal state, which is why we check the number of
        # kv_cache_groups instead of solely checking
        # self.model_config.is_attention_free.
        if len(self.kv_cache_config.kv_cache_groups) == 0:
            return

        if len(self.kv_cache_config.kv_cache_groups) > 1:
            raise ValueError("Multiple KVCacheGroups are not "
                             "currently supported with this model runner.")

        # Guard against encoder-only / pooling models where `attn_groups`
        # may be empty or lack the expected metadata builder. Without this
        # check, accessing `attn_groups[0][0]` would trigger an
        # AssertionError.
        if not hasattr(self, "attn_groups") or not self.attn_groups:
            return
        if not self.attn_groups[0]:
            return

        from vllm_vacc.vllm.v1.attention.backends.vacc_attn import (
            VACCAttentionMetadataBuilderV1)

        mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
        if isinstance(mb, list):
            if not isinstance(mb[0], VACCAttentionMetadataBuilderV1):
                return
            mb[0].reorder_batch(self.input_batch, scheduler_output)
            return
        elif not isinstance(mb, VACCAttentionMetadataBuilderV1):
            # Encoder-only / rerank models do not benefit from reordering,
            # so we can safely skip here.
            return

        # Safe path for decoder / attention-heavy models.
        mb.reorder_batch(self.input_batch, scheduler_output)


@contextmanager
def _set_global_compilation_settings(config: VllmConfig):
    import torch._inductor.config

    inductor_config = config.compilation_config.inductor_compile_config
    try:
        # Note: the MKLDNN and CPPGEMM backends require freezing parameters.
        freezing_value = torch._inductor.config.freezing
        if inductor_config.get("max_autotune", False):
            torch._inductor.config.freezing = True
        yield
    finally:
        torch._inductor.config.freezing = freezing_value
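

# Illustrative sketch, added for documentation only: this helper is not used
# by the runner, and its name (_spec_decode_indices_example) is not part of
# any vLLM API. It reproduces, with plain NumPy, the index arithmetic that
# _calc_spec_decode_metadata documents in its comments, so the worked example
# there can be verified by simply calling this function.
def _spec_decode_indices_example() -> None:
    # `np` is numpy, imported at the top of this module.
    # The example from _calc_spec_decode_metadata's comments.
    cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209])
    num_draft_tokens = np.array([3, 0, 2, 0, 1])

    # One sampled token per request plus its draft tokens.
    num_sampled_tokens = num_draft_tokens + 1
    cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
    arange = np.concatenate([np.arange(n) for n in num_sampled_tokens])

    # Positions whose logits are needed (last scheduled tokens per request).
    logits_indices = np.repeat(
        cu_num_scheduled_tokens - num_sampled_tokens,
        num_sampled_tokens) + arange
    # The bonus token logits are the last sampled position of each request.
    bonus_logits_indices = cu_num_sampled_tokens - 1

    # Logit positions that verify the draft tokens.
    draft_arange = np.concatenate([np.arange(n) for n in num_draft_tokens])
    target_logits_indices = np.repeat(
        cu_num_sampled_tokens - num_sampled_tokens,
        num_draft_tokens) + draft_arange

    assert logits_indices.tolist() == [
        0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
    assert bonus_logits_indices.tolist() == [3, 4, 7, 8, 10]
    assert target_logits_indices.tolist() == [0, 1, 2, 5, 6, 9]
    # Draft token positions within input_ids (the token after each target).
    assert logits_indices[target_logits_indices + 1].tolist() == [
        1, 2, 3, 105, 106, 208]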