#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Optional

import torch
from vllm.forward_context import set_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.model_runner_base import (ModelRunnerBase,
                                           ModelRunnerInputBase,
                                           ModelRunnerWrapperBase)

from vllm_ascend.attention.attention import AscendMetadata

# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True


class TP1DraftModelRunner(ModelRunnerWrapperBase):
    """Specialized model runner for the speculative decoding draft model.

    Since the draft model always executes k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we can get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on the device for
    the whole step.

    TODOs:
    1. The multi-step path is currently gated on the FLASH_ATTN /
       TRITON_MLA backend names in supports_gpu_multi_step(..), while
       _gpu_advance_step(..) expects AscendMetadata; enable it for the
       Ascend attention backend.
    2. Support TP > 1 (this requires some design work because we do not
       expect any broadcasting inside execute_model).
    """

    def __init__(self, model_runner: ModelRunnerBase):
        """Wrap ``model_runner`` for multi-step draft-token generation.

        Args:
            model_runner: The underlying model runner being wrapped.

        Raises:
            ValueError: If ``model_runner`` has a truthy
                ``return_hidden_states`` attribute, which this runner
                does not support.
        """
        if hasattr(
                model_runner,
                "return_hidden_states") and model_runner.return_hidden_states:
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )
        super().__init__(model_runner)

        # Set externally via set_indices_of_seq_with_bonus_tokens(..) and
        # consumed in execute_model(..); None disables the bonus-token fixup.
        self.indices_of_seq_with_bonus_tokens = None

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):
        """Sanity-check that ``sampling_metadata`` describes a pure-decode batch.

        Despite its name, this method does not modify ``sampling_metadata``.
        It only asserts that there are no prompts and that every one of the
        ``num_queries`` sequence groups samples exactly one index equal to
        its own position. ``num_seqs`` is currently unused.
        """
        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple

    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
                          last_output: SamplerOutput) -> ModelRunnerInputBase:
        """Build the model input for the next draft step without a CPU round trip.

        Delegates to ``AscendMetadata.advance_step(..)`` with the tokens
        sampled in ``last_output``, then wraps the (reused) tensors of
        ``model_input`` in a fresh model-input object.

        Args:
            model_input: Decode-only input of the step that just ran.
            last_output: Sampler output of that step; its
                ``sampled_token_ids`` must be populated.

        Returns:
            A new model input that shares ``input_tokens``,
            ``input_positions``, ``attn_metadata`` and ``sampling_metadata``
            with ``model_input``.
        """
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        # Get num_seqs
        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Get output tokens GPU tensor
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Update attn_metadata.
        # NOTE(review): the new input below reuses model_input.input_tokens /
        # input_positions, so advance_step(..) presumably updates them in
        # place on the device -- confirm against AscendMetadata.
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, AscendMetadata)

        attn_metadata.advance_step(model_input, sampled_token_ids,
                                   self.block_size, num_seqs, num_queries)

        # Update sampling_metadata
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        # Create new input
        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            # Notes: If vllm_ascend supports LORA, we need to
            # add the following two params.
            # lora_mapping=model_input.lora_mapping,
            # lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
        # We can reuse sampling tensors since every decode iteration is the same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug(" input_tokens = %s", new_model_input.input_tokens)
            logger.debug(" input_positions = %s",
                         new_model_input.input_positions)
            # seq_lens / query_lens are sequences (len() is taken on them
            # above), so they must be formatted with %s -- %d raises a
            # TypeError inside logging.
            logger.debug(" seq_lens = %s", new_model_input.seq_lens)
            logger.debug(" query_lens = %s", new_model_input.query_lens)
            logger.debug(" attn_metadata:")
            logger.debug(" seq_lens_tensor: %s",
                         attn_metadata.seq_lens_tensor)
            logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping)
            logger.debug(" block_tables: %s", attn_metadata.block_tables)

        return new_model_input

    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        """Determine whether draft_model_runner GPU multi-step can be used.

        Currently required conditions are:
            1. allow_gpu_advance_step is enabled
            2. Only decodes (no prompt sequence groups)
            3. Attention backend named FLASH_ATTN or TRITON_MLA (the Ascend
               backend is not accepted yet, see the TODO below)
            4. No LORA
            5. No prompt_adapter_config

        Returns:
            True if all conditions hold, otherwise False.
        """
        if not allow_gpu_advance_step:
            return False

        # We allow multi-step GPU only in decode mode
        for seq_group in execute_model_req.seq_group_metadata_list:
            if seq_group.is_prompt:
                return False

        # TODO: Add support for ASCEND once the outer multi_step_worker
        # works correctly with it.
        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
            return False

        # TODO: Add support for LORA
        if self.lora_config:
            return False

        # TODO: Add soft-tuning prompt adapter support
        return not self.prompt_adapter_config

    def set_indices_of_seq_with_bonus_tokens(self,
                                             indices_of_seq_with_bonus_tokens):
        """Record the batch rows that carry a bonus token.

        Read by execute_model(..) for decode steps; see the fixup there.
        """
        self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelRunnerInputBase,
        kv_caches: List[torch.Tensor],
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        """Execute num_steps forward passes, advancing inputs on the device.

        Look at supports_gpu_multi_step(..) for pre-conditions.

        Optimizations used:
            1. Input tensors are updated on the GPU directly
            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
               them since we do batch expansion later that uses GPU outputs)
            3. Reuses sampling tensors (since we run only decodes and they have
               a repeating sampling logic)

        Args:
            model_input: Prepared input for the first step.
            kv_caches: Not referenced in this method; kept for interface
                compatibility.
            previous_hidden_states: If given, passed to the model as
                ``previous_hidden_states`` on every step. From the second
                step on, the previous step's model output is passed instead.
            intermediate_tensors: Forwarded to the model unchanged.
            num_steps: Number of draft steps. 1 selects the CPU-prepared
                fallback path.
            **kwargs: ``spec_step_idx`` is read for models whose config has
                ``num_nextn_predict_layers`` (DeepSeek MTP).

        Returns:
            One SamplerOutput per executed step, or ``[]`` on a non-driver
            worker.

        Raises:
            ValueError: On multi-step execution from a non-driver worker or
                with LORA / prompt adapters / multi-modal inputs, or when a
                prefill batch is run with num_steps > 1.
        """
        # When num_steps == 1, we execute the fallback here for the GPU
        # advance_step, which runs prepare_inputs on CPU and for each spec
        # iteration invokes this function only once
        # (Look at multi-step-worker code)
        is_fallback = num_steps == 1
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
            # case, because we will at least need to broadcast the sampled
            # tokens to all workers.
            if not self.is_driver_worker:
                raise ValueError("TP1DraftModelRunner only supports TP=1.")

            # Sanity
            if self.lora_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for LORA")
            if self.prompt_adapter_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "prompt_adapter_config")
            if model_input.multi_modal_kwargs:
                raise ValueError(
                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
                )
        else:
            if self.lora_config:
                assert model_input.lora_requests is not None
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)

            if self.prompt_adapter_config:
                assert model_input.prompt_adapter_requests is not None
                assert model_input.prompt_adapter_mapping is not None
                self.set_active_prompt_adapters(
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

            self.attn_state.begin_forward(model_input)

        # Detect exec mode
        assert model_input.attn_metadata is not None
        if model_input.attn_metadata.num_prefills > 0:
            # In this case, execute_model(..) was called directly
            if num_steps > 1:
                raise ValueError(
                    "execute_model(..) of draft_model_runner can be called "
                    "directly only with a single-step prefill")
        else:
            # We can skip CPU samples for spec token generation.
            # (We do allow CPU samples for num_steps == 1 to support the
            # fallback case, where supports_gpu_multi_step(..) does not pass)
            model_input.sampling_metadata.skip_sampler_cpu_output = (
                not is_fallback)

        model_executable = self.model
        hidden_states = previous_hidden_states

        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

            # Once provided, hidden states are chained step to step: from the
            # second step on, this is the previous step's model output.
            model_execute_kwargs = ({
                "previous_hidden_states": hidden_states
            } if previous_hidden_states is not None else {})

            compute_logits_kwargs = {}
            # Run model
            if hasattr(self.model.config, "num_nextn_predict_layers"):
                # for DeepSeek MTP only to use the corresponding layer for
                # each step
                spec_step_idx = kwargs.get("spec_step_idx", step)
                model_execute_kwargs["spec_step_idx"] = spec_step_idx
                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
            with set_forward_context(model_input.attn_metadata,
                                     self.vllm_config):

                # Keep the attention metadata's positions in sync with the
                # (possibly advanced) model input before the forward pass.
                if model_input.attn_metadata is not None:
                    model_input.attn_metadata.input_positions = (
                        model_input.input_positions)

                hidden_states = model_executable(
                    input_ids=model_input.input_tokens,
                    positions=model_input.input_positions,
                    intermediate_tensors=intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                 device=self.device),
                    **model_execute_kwargs,
                )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata,
                                               **compute_logits_kwargs)
            if not self.is_driver_worker:
                return []
            # Sample the next token.
            assert self.model_runner.sampler is not None
            output = self.model_runner.sampler(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata,
            )
            outputs.append(output)

            if (model_input.attn_metadata.num_prefills == 0
                    and self.indices_of_seq_with_bonus_tokens is not None):
                assert output.sampled_token_ids is not None
                # output.sampled_token_ids should be of shape (num_seqs, 1)
                nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
                assert num_tokens_per_seq == 1
                # Walk the rows in order: every row that is not listed in
                # indices_of_seq_with_bonus_tokens gets its sampled token
                # overwritten with the input token of the next listed row.
                # NOTE(review): this assumes the list is sorted and that the
                # last row is always listed; otherwise `count` would run
                # past the end of the list.
                count = 0
                for i in range(nums_seqs):
                    bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
                        count]
                    if i != bonus_seq_idx:
                        # The following might cause a cpu->gpu sync
                        # However, the performance impact is negligible as we
                        # benchmarked on H100.
                        output.sampled_token_ids[
                            i, :] = model_input.input_tokens[bonus_seq_idx]
                    else:
                        count += 1

            # Prepare inputs for the next step
            if step != num_steps - 1:
                model_input = self._gpu_advance_step(model_input, outputs[-1])

        return outputs