From c498cea22d1aeff9ddfa376db728c616e553628a Mon Sep 17 00:00:00 2001
From: Wang Kunpeng <1289706727@qq.com>
Date: Tue, 27 Jan 2026 22:27:01 +0800
Subject: [PATCH] [refactor] refactor execute_model and _dummy_run methods (#6043)

### What this PR does / why we need it?
The structure of the `execute_model` and `_dummy_run` methods in NPUModelRunner differs greatly from that in GPUModelRunner. To align with GPUModelRunner:
- Split the `_prepare_inputs` method into `_prepare_inputs`, `_determine_batch_execution_and_padding`, `_build_attention_metadata`, and `_preprocess`.
- Rename `_generate_process_reqs_hidden_states` to `_model_forward`.
- Align the implementation of the `postprocess` phase.

**Related-RFC**: https://github.com/vllm-project/vllm-ascend/issues/5449
**Co-authored-by**: @zhenwenqi2024

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60

---------
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: gcanlin
Signed-off-by: zhenwenqi2024
Co-authored-by: gcanlin
Co-authored-by: zhenwenqi2024
---
 tests/ut/worker/test_pcp_manager.py | 1 +
 vllm_ascend/attention/sfa_v1.py | 2 +-
 vllm_ascend/worker/model_runner_v1.py | 1457 ++++++++++++++-----------
 vllm_ascend/worker/pcp_utils.py | 15 +-
 4 files changed, 825 insertions(+), 650 deletions(-)
diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py index 9a6779c1..3f5ea17a 100644 --- a/tests/ut/worker/test_pcp_manager.py +++ b/tests/ut/worker/test_pcp_manager.py @@ -123,6 +123,7 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens, vllm_config = MagicMock() vllm_config.model_config = MagicMock() vllm_config.speculative_config.num_speculative_tokens = 0 + vllm_config.scheduler_config.max_num_seqs = 1000 pcp_manager = PCPManager(pcp_world_size=pcp_size, pcp_rank=0, diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 56730cc5..6f283429 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -170,7 +170,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): npu_fused_infer_attention_score TND layout's limit of 16, \ got {self.decode_threshold}" ) - + self.reorder_batch_threshold = self.decode_threshold self.attn_mask_builder = AttentionMaskBuilder(self.device) self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim self.enable_dsa_cp = enable_dsa_cp() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 23947740..3b5f2812 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -24,7 +24,7 @@ from contextlib import contextmanager, nullcontext from copy import copy, deepcopy from dataclasses import dataclass from multiprocessing import Manager -from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional, Union, TypeAlias, Tuple import numpy as np import torch @@ -33,6 +33,7 @@ import torch.nn as nn from vllm.attention.layer import Attention, MLAAttention from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config) +from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.distributed import (get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather) from vllm.distributed.ec_transfer import get_ec_transfer,
has_ec_transfer @@ -41,16 +42,16 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group, get_pcp_group, get_pp_group, get_tp_group) -from vllm.forward_context import get_forward_context +from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.model_loader import get_model from vllm.sequence import IntermediateTensors from vllm.utils.import_utils import LazyLoader -from vllm.utils.math_utils import cdiv +from vllm.utils.math_utils import cdiv, round_up from vllm.utils.mem_utils import DeviceMemoryProfiler -from vllm.v1.attention.backend import AttentionBackend, AttentionType # type: ignore +from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.selector import get_attn_backend # type: ignore @@ -61,8 +62,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheGroupSpec, KVCacheSpec, MambaSpec, UniformTypeKVCacheSpecs) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - LogprobsLists, LogprobsTensors, ModelRunnerOutput, - SamplerOutput, ECConnectorOutput, + ECConnectorOutput, LogprobsLists, LogprobsTensors, + ModelRunnerOutput, SamplerOutput, make_empty_encoder_model_runner_output) from vllm.v1.sample.logits_processor import build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -75,6 +76,10 @@ from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput, GPUModelRunner) from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm.v1.worker.utils import AttentionGroup +from vllm.v1.worker.ubatch_utils import ( + UBatchSlices, + maybe_create_ubatch_slices, +) from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState @@ -121,6 +126,10 @@ import torch_npu # if true, allow tensor initialization and casting with internal format (e.g., NZ) torch.npu.config.allow_internal_format = True +AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] +# list when ubatching is enabled +PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict + if get_ascend_device_type() == AscendDeviceType._310P: torch_npu.npu.set_compile_mode(jit_compile=False) @@ -176,12 +185,13 @@ class ExecuteModelState(NamedTuple): scheduler_output: "SchedulerOutput" logits: torch.Tensor spec_decode_metadata: SpecDecodeMetadata | None + spec_decode_common_attn_metadata: AscendCommonAttentionMetadata | None hidden_states: torch.Tensor sample_hidden_states: torch.Tensor aux_hidden_states: list[torch.Tensor] | None - attn_metadata: dict[str, Any] + attn_metadata: "PerLayerAttnMetadata" positions: torch.Tensor - ec_connector_output: ECConnectorOutput | None + ec_connector_output: "ECConnectorOutput | None" class NPUModelRunner(GPUModelRunner): @@ -235,7 +245,7 @@ class NPUModelRunner(GPUModelRunner): self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64) self.sampler = AscendSampler() - self.attn_state = None + self.attn_state: AscendAttentionState | None = None # Ascend-specific configurations self.ascend_config = get_ascend_config() @@ -354,6 +364,10 @@ class 
NPUModelRunner(GPUModelRunner): self.reorder_batch_threshold: int | None = None self.long_seq_metadata = None + @property + def use_cp(self) -> bool: + return self.pcp_size * self.dcp_size > 1 + def _init_device_properties(self) -> None: self.num_sms = None @@ -497,9 +511,15 @@ class NPUModelRunner(GPUModelRunner): def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> tuple[dict[str, Any], np.ndarray, int, Optional[torch.Tensor], - torch.Tensor, Optional[SpecDecodeMetadata], int]: + num_scheduled_tokens: np.ndarray, + ) -> tuple[ + torch.Tensor, + SpecDecodeMetadata | None]: + """ + :return: tuple[ + logits_indices, spec_decode_metadata, + ] + """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -509,23 +529,17 @@ class NPUModelRunner(GPUModelRunner): # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit_block_table(num_reqs) - # Get the number of scheduled tokens for each request. - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens = np.array(tokens, dtype=np.int32) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) + # Get the attention state. if not scheduler_output.scheduled_spec_decode_tokens: - num_valid_tokens = np.array(tokens, dtype=np.int32) + num_valid_tokens = num_scheduled_tokens else: num_valid_tokens = np.array([ - num_tokens - + scheduler_output.num_scheduled_tokens[i] - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) - for num_tokens, i in zip(tokens, req_ids) - ], - dtype=np.int32) - # Get the attention state. + for i in self.input_batch.req_ids + ], dtype=np.int32) attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) self.attn_state = attn_state # type: ignore @@ -534,6 +548,7 @@ class NPUModelRunner(GPUModelRunner): with_prefill = attn_state not in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] + self.with_prefill = with_prefill # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] @@ -583,59 +598,8 @@ class NPUModelRunner(GPUModelRunner): position_pcp[:total_num_scheduled_tokens], out=positions_np, ) - max_num_scheduled_tokens = max(tokens) - if (self.use_aclgraph and total_num_scheduled_tokens - <= self.cudagraph_batch_sizes[-1]): - # Add padding to the batch size. - if vllm_version_is('0.14.1'): - num_input_tokens = self.vllm_config.pad_for_cudagraph( - total_num_scheduled_tokens) - else: - num_input_tokens = self.cudagraph_dispatcher._bs_to_padded_graph_size[ - total_num_scheduled_tokens] - elif self.use_aclgraph and enable_sp(self.vllm_config): - # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size, - # the model will fall back to running its FX graph in eager mode. - # In this case, when sequence parallelism is enabled, we need to pad tokens to align - # with tp_size because pad_size cannot be captured by the FX graph - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - num_input_tokens = math.ceil( - total_num_scheduled_tokens / tp_size) * tp_size - else: - # Eager mode. - num_input_tokens = total_num_scheduled_tokens self.query_lens = torch.from_numpy(num_scheduled_tokens) - # Get info across DP ranks. 
- (num_input_tokens, num_tokens_across_dp, - with_prefill) = self._sync_metadata_across_dp(num_input_tokens, - with_prefill) - self.with_prefill = with_prefill - - # Hot-Swap lora model - if self.lora_config: - self.set_active_loras(self.input_batch, num_scheduled_tokens) - - # Calculate M-RoPE positions. - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.uses_mrope: - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self._calc_mrope_positions(scheduler_output) - self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( - self.mrope_positions.cpu[:, :total_num_scheduled_tokens], - non_blocking=True, - ) - elif self.uses_xdrope_dim > 0: - self._calc_xdrope_positions(scheduler_output) - # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) - self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( - self.xdrope_positions.cpu[:, :total_num_scheduled_tokens], - non_blocking=True, - ) - else: - # Common case (1D positions) - self.positions.copy_to_gpu(total_num_scheduled_tokens) - # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] @@ -701,7 +665,10 @@ class NPUModelRunner(GPUModelRunner): self.query_start_loc.np[0] = 0 self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens - self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1]) + # NOTE: Due to the FIA operator limitation, here we pad so that hidden_states.shape[0] + # and self.query_start_loc[num_reqs_padded] are equal + self.query_start_loc.np[num_reqs + 1:] = (self.arange_np[1:self.max_num_reqs + 1 - num_reqs] + * self.uniform_decode_query_len + cu_num_tokens[-1]) self.query_start_loc.copy_to_gpu() self.seq_lens.np[:num_reqs] = ( @@ -714,9 +681,25 @@ class NPUModelRunner(GPUModelRunner): # Copy the tensors to the NPU. self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens, cu_num_tokens) - self.positions.cpu[total_num_scheduled_tokens:num_input_tokens].zero_() - self.positions.copy_to_gpu() - attn_metadata: dict[str, Any] = {} + # Calculate M-RoPE positions. 
+ # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self._calc_mrope_positions(scheduler_output) + self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions.cpu[:, :total_num_scheduled_tokens], + non_blocking=True, + ) + elif self.uses_xdrope_dim > 0: + self._calc_xdrope_positions(scheduler_output) + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( + self.xdrope_positions.cpu[:, :total_num_scheduled_tokens], + non_blocking=True, + ) + else: + # Common case (1D positions) + self.positions.copy_to_gpu(total_num_scheduled_tokens) # Record the index of requests that should not be sampled, # so that we could clear the sampled tokens before returning @@ -730,7 +713,7 @@ class NPUModelRunner(GPUModelRunner): # while pcp > 1, we need the original num_scheduled_tokens before split # to calculate discard_requests_mask tokens_original = [ - scheduler_output.num_scheduled_tokens[i] for i in req_ids + scheduler_output.num_scheduled_tokens[i] for i in self.input_batch.req_ids ] original_seq_lens_np = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + @@ -744,9 +727,7 @@ class NPUModelRunner(GPUModelRunner): self.discard_request_indices.np[:self.num_discarded_requests] = ( discard_request_indices) self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) - - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token @@ -754,6 +735,8 @@ class NPUModelRunner(GPUModelRunner): # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. spec_decode_metadata = None + num_draft_tokens = None + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) if self.pcp_size * self.dcp_size > 1: logits_indices = self.pcp_manager.get_logits_indices( cu_num_tokens, num_reqs) @@ -769,8 +752,10 @@ class NPUModelRunner(GPUModelRunner): # For chunked prefills, use -1 as mask rather than 0, as guided # decoding may rollback speculative tokens. num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] num_draft_tokens[req_idx] = len(draft_token_ids) num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if ( @@ -783,264 +768,31 @@ class NPUModelRunner(GPUModelRunner): num_pcp_pads=self.pcp_manager.num_pcp_pads_cpu[:num_reqs] if self.pcp_size > 1 else None) logits_indices = spec_decode_metadata.logits_indices + num_sampled_tokens = num_draft_tokens + 1 # For DECODE only cuda graph of some attention backends (e.g., GDN). - self.num_decode_draft_tokens.np[: - num_reqs] = num_decode_draft_tokens + self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens self.num_decode_draft_tokens.np[num_reqs:].fill(-1) self.num_decode_draft_tokens.copy_to_gpu() # save logits_indices for pcp spec decode usage self.logits_indices = logits_indices - # Used in the below loop. 
- self.spec_decode_common_attn_metadata = None - if use_spec_decode and self.need_accepted_tokens: - self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs]) - self.num_accepted_tokens.np[num_reqs:].fill(1) - self.num_accepted_tokens.copy_to_gpu() - - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens( - scheduler_output.num_scheduled_tokens or {}, - kv_cache_group_spec.kv_cache_spec, - self.input_batch.num_reqs, + # Hot-Swap lora model + if self.lora_config: + assert ( + np.sum(num_sampled_tokens) + <= self.vllm_config.scheduler_config.max_num_batched_tokens + ) + self.set_active_loras( + self.input_batch, num_scheduled_tokens, num_sampled_tokens ) - if isinstance(kv_cache_group_spec.kv_cache_spec, - EncoderOnlyAttentionSpec): - # Encoder-only layers do not have KV cache, so we need to - # create a dummy block table and slot mapping for them. - blk_table_tensor = torch.zeros( - (num_reqs, 1), - dtype=torch.int32, - device=self.device, - ) - slot_mapping = torch.zeros( - (total_num_scheduled_tokens, ), - dtype=torch.int64, - device=self.device, - ) - else: - maybe_pcp_full_tokens = ( - num_input_tokens if self.pcp_size == 1 else - total_num_scheduled_tokens * self.pcp_size - - sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs])) - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor() - slot_mapping = blk_table.slot_mapping.gpu[: - maybe_pcp_full_tokens] - if self.pcp_size == 1: - slot_mapping[ - total_num_scheduled_tokens:num_input_tokens].fill_(-1) - if self.pcp_size * self.dcp_size > 1: - self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata( - total_num_scheduled_tokens, self.query_lens, - self.input_batch, num_scheduled_tokens) - blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1) - if self.pcp_size > 1: - slot_mapping_pcp = self.pcp_manager.get_padded_slot_mapping( - total_num_scheduled_tokens, - slot_mapping, - ) - blk_table.slot_mapping.gpu[:self.pcp_manager. - num_actual_tokens_pcp_padded] = slot_mapping_pcp - slot_mapping = blk_table.slot_mapping.gpu[:self. - pcp_manager. - num_actual_tokens_pcp_padded] - - # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs - # has been split to multiple parts, and there are 3 parts that is related to this - # `num_reqs`, we'll take `query_start_loc` as an example: - # 1. self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens - # 2. get `num_reqs_padded`, this depends on dispatcher and which is why we have the - # following simplified `dispatch` logic here, we try to minimize the impact - # 3. query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1] - uniform_decode = (max_num_scheduled_tokens == self.uniform_decode_query_len) \ - and (total_num_scheduled_tokens == max_num_scheduled_tokens * num_reqs) - - # TODO: We should make this official ASAP. Also note that if we pad here, - # the builders won’t need to add any extra padding. 
- if self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ - uniform_decode: - max_decode_tokens = min( - self.scheduler_config.max_num_seqs * - self.uniform_decode_query_len, - self.cudagraph_batch_sizes[-1]) - if self.uniform_decode_query_len <= num_input_tokens <= max_decode_tokens: - num_reqs_padded = num_input_tokens // self.uniform_decode_query_len - pad_size = num_reqs_padded - num_reqs - if pad_size > 0: - last_query_loc = self.query_start_loc.np[num_reqs] - - self.query_start_loc.np[ - num_reqs + 1:num_reqs_padded + 1] = self.arange_np[ - 1:pad_size + - 1] * self.uniform_decode_query_len + last_query_loc - self.query_start_loc.copy_to_gpu(num_reqs_padded + 1) - self.seq_lens.np[num_reqs:].fill(0) - self.seq_lens.copy_to_gpu(num_reqs_padded) - - # So we are trying to simulate the behavior of GPUModelRunner's - # prepare_inputs for uniform decode mode by padding query_start_loc - num_reqs = num_reqs_padded - - # Make AscendCommonAttentionMetadata - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + 1], - seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - seq_lens=self.seq_lens.gpu[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - num_input_tokens=num_input_tokens, - actual_seq_lengths_q=self.actual_seq_lengths_q, - # TODO: change this to the right block table for linear attn - block_table_tensor=blk_table_tensor[:num_reqs], - slot_mapping=slot_mapping, - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], - positions=self.positions.gpu, - attn_state=self.attn_state, - max_query_len=max_num_scheduled_tokens, - decode_token_per_req=self.decode_token_per_req, - prefill_context_parallel_metadata=self.long_seq_metadata, - max_seq_len=0, - encoder_seq_lens=encoder_seq_lens, - encoder_seq_lens_cpu=encoder_seq_lens_cpu) - - if self.speculative_config and self.pcp_size * self.dcp_size > 1: - # For pcp + spec decode, we flatten block_table - # to avoid irregular attn_mask shape, e.g., - # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1, - # ori block_table: # [d0, d1, p0, p1, p2] - # (num_reqs_d + num_reqs_p, max_num_blocks), - # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] - # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), - ori_query_lens_cpu = self.pcp_manager.query_lens_pcp_full.cpu[: - num_reqs] - ori_query_lens = self.pcp_manager.query_lens_pcp_full.gpu[: - num_reqs] - num_prefill_reqs = (ori_query_lens - > self.decode_threshold).sum().item() - num_decode_reqs = num_reqs - num_prefill_reqs - num_decode_reqs_flatten = \ - ori_query_lens_cpu[:num_decode_reqs].sum().item() - blk_table_tensor[ - num_decode_reqs_flatten:num_decode_reqs_flatten + - num_prefill_reqs].copy_( - blk_table_tensor[num_decode_reqs:num_decode_reqs + - num_prefill_reqs].clone()) - blk_table_tensor[:num_decode_reqs_flatten].copy_( - blk_table_tensor[:num_decode_reqs].repeat_interleave( - ori_query_lens[:num_decode_reqs], dim=0)) - common_attn_metadata.block_table_tensor = \ - blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs] - assert self.long_seq_metadata is not None - self.long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu - - if 'pad_size' in locals() and pad_size > 0: - ori_query_lens_cpu[-pad_size:] = \ - torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item()) - self.long_seq_metadata.max_query_len_pcp_full = \ - 
ori_query_lens_cpu.max().item() - - - - if self.speculative_config and \ - self.spec_decode_common_attn_metadata is None: - self.spec_decode_common_attn_metadata = common_attn_metadata - if num_reqs != base_num_reqs or total_num_scheduled_tokens != num_input_tokens: - self.spec_decode_common_attn_metadata = \ - self.spec_decode_common_attn_metadata.unpadded( - total_num_scheduled_tokens, base_num_reqs) - - for attn_group in self.attn_groups[kv_cache_group_id]: - common_prefix_len = 0 - extra_attn_metadata_args = {} - builder = attn_group.get_metadata_builder() - if isinstance(builder, GDNAttentionMetadataBuilder): - if use_spec_decode: - patch_torch_npu_argsort() - extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens. - gpu[:num_reqs], - num_decode_draft_tokens_cpu=self. - num_decode_draft_tokens.cpu[:num_reqs], - ) - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - - for layer_name in attn_group.layer_names: - attn_metadata[layer_name] = attn_metadata_i - if lmhead_tp_enable(): max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len logits_indices = nn.functional.pad( logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0])) - return (attn_metadata, num_scheduled_tokens, num_input_tokens, - num_tokens_across_dp, logits_indices, spec_decode_metadata, - max_num_scheduled_tokens) - - # all-gather one hidden-states in sp scene - @staticmethod - def _all_gather_hidden_states(hidden_states): - hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) - pad_size = get_forward_context().pad_size - if pad_size > 0: - hidden_states = hidden_states[:-pad_size, :] - - return hidden_states - - # all-gather a list of hidden-states in sp scene - @staticmethod - def _all_gather_hidden_states_list(hidden_states_list): - return [ - NPUModelRunner._all_gather_hidden_states(hidden_states) - for hidden_states in hidden_states_list - ] - - # all-gather hidden-states in last layer with aux-hidden-states in sp scene - @staticmethod - def _all_gather_hidden_states_and_aux(hidden_states): - if isinstance(hidden_states, tuple): - return (NPUModelRunner._all_gather_hidden_states(hidden_states[0]), - NPUModelRunner._all_gather_hidden_states_list( - hidden_states[1])) - return NPUModelRunner._all_gather_hidden_states(hidden_states) - - def _generate_process_reqs_hidden_states(self, num_input_tokens, - input_ids, positions, - intermediate_tensors, - inputs_embeds, model_kwargs): - assert self.model is not None - hidden_states = self.model(input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs) - - forward_context = get_forward_context() - if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \ - and not self.use_sparse: - update_full_graph_params(self.attn_backend, self.update_stream, forward_context, - num_input_tokens, self.vllm_config, - self.vllm_config.speculative_config) - - if get_forward_context().sp_enabled and not isinstance( - hidden_states, IntermediateTensors): - hidden_states = self._all_gather_hidden_states_and_aux( - hidden_states) - if self.pcp_size > 1 and get_pp_group().is_last_rank: - hidden_states = self.pcp_manager.get_restore_hidden_states( - hidden_states) - return hidden_states + return logits_indices, spec_decode_metadata def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens): @@ -1168,10 +920,11 @@ class 
NPUModelRunner(GPUModelRunner): sampling_metadata: SamplingMetadata, scheduler_output: "SchedulerOutput", spec_decode_metadata: SpecDecodeMetadata, + spec_decode_common_attn_metadata: AscendCommonAttentionMetadata, positions: torch.Tensor, num_scheduled_tokens: int, hidden_states: torch.Tensor, - attn_metadata: dict[str, Any], + attn_metadata: list[dict[str, Any]] | dict[str, Any], aux_hidden_states: torch.Tensor = None, sample_hidden_states: torch.Tensor = None ) -> Optional[list[list[int]]]: @@ -1189,7 +942,7 @@ class NPUModelRunner(GPUModelRunner): valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states) elif self.speculative_config.use_eagle(): - common_attn_metadata = self.spec_decode_common_attn_metadata + common_attn_metadata = spec_decode_common_attn_metadata sampled_token_ids = valid_sampled_token_ids if self.vllm_config.speculative_config.disable_padded_drafter_batch: @@ -1325,64 +1078,169 @@ class NPUModelRunner(GPUModelRunner): return draft_token_ids - @staticmethod - def get_finished_kv_transfer( - scheduler_output: "SchedulerOutput", - ) -> tuple[Optional[set[str]], Optional[set[str]]]: - if has_kv_transfer_group(): - return get_kv_transfer_group().get_finished( - scheduler_output.finished_req_ids) - return None, None - @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, IntermediateTensors] | None: + ) -> ModelRunnerOutput | IntermediateTensors | None: if self.execute_model_state is not None: raise RuntimeError("State error: sample_tokens() must be called " "after execute_model() returns None.") - + # self._draft_token_ids is None when `input_fits_in_drafter=False` + # and there is no draft tokens scheduled. so it need to update the + # spec_decoding info in scheduler_output with async_scheduling. + # use deepcopy to avoid the modification has influence on the + # scheduler_output in engine core process. + # TODO(Ronald1995): deepcopy is expensive when there is a large + # number of requests, optimize it later. + if ( + self.use_async_scheduling + and self.num_spec_tokens + and self._draft_token_ids is None # type: ignore[has-type] + ): + scheduler_output = deepcopy(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with ProfileExecuteDuration().capture_async("prepare input"): - self._update_states(scheduler_output) - if has_ec_transfer() and get_ec_transfer().is_producer: - with self.maybe_get_ec_connector_output( + with self.synchronize_input_prep(): + # Update persistent batch states. 
+ self._update_states(scheduler_output) + + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( scheduler_output, encoder_cache=self.encoder_cache, - ) as ec_connector_output: - self._execute_mm_encoder(scheduler_output) - return make_empty_encoder_model_runner_output( - scheduler_output) + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - logger.debug( - "skip this step for we receive the data from remote disaggregate prefill node" + if not num_scheduled_tokens: + if ( + self.parallel_config.distributed_executor_backend + == "external_launcher" + and self.parallel_config.data_parallel_size > 1 + ): + # this is a corner case when both external launcher + # and DP are enabled, num_scheduled_tokens could be + # 0, and has_unfinished_requests in the outer loop + # returns True. before returning early here we call + # dummy run to ensure coordinate_batch_across_dp + # is called into to avoid out of sync issues. + self._dummy_run(1) + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward( + scheduler_output, self.vllm_config ) - # Return empty ModelRunnerOuptut if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - return self.kv_connector_no_forward(scheduler_output, - self.vllm_config) + if self.cache_config.kv_sharing_fast_prefill: + assert not self.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" + ) + num_reqs = self.input_batch.num_reqs + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) - if self.dynamic_eplb: - self.eplb_updator.forward_before() + ( + logits_indices, + spec_decode_metadata, + ) = self._prepare_inputs( + scheduler_output, + num_scheduled_tokens_np, + ) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + if self.pcp_size > 1: + num_tokens_unpadded = self.pcp_manager.total_num_sampled_tokens_pcp + cascade_attn_prefix_lens = None + # Disable cascade attention when using microbatching (DBO) + if self.cascade_attn_enabled and not self.parallel_config.enable_dbo: + # Pre-compute cascade attention prefix lengths + cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( + num_scheduled_tokens_np, + self.input_batch.num_computed_tokens_cpu[:num_reqs], + scheduler_output.num_common_prefix_blocks, + ) - (attn_metadata, num_scheduled_tokens_np, num_input_tokens, - num_tokens_across_dp, logits_indices, spec_decode_metadata, - max_query_len) = self._prepare_inputs(scheduler_output) + ( + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) = self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens_np, + max_num_scheduled_tokens=max_num_scheduled_tokens, + use_cascade_attn=cascade_attn_prefix_lens is not None, + num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), + ) - (input_ids, inputs_embeds, positions, intermediate_tensors, - model_kwargs, ec_connector_output) = 
self._preprocess(scheduler_output, - num_input_tokens, - intermediate_tensors) + logger.debug( + "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " + "should_ubatch: %s, num_tokens_across_dp: %s", + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + ) + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens_np, + num_tokens_padded, + num_reqs_padded, + self.parallel_config.num_ubatches, + ) + + pad_attn = cudagraph_mode == CUDAGraphMode.FULL + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices + + (attn_metadata, spec_decode_common_attn_metadata) = ( + self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + num_scheduled_tokens_np=num_scheduled_tokens_np, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, + ) + ) + + ( + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ec_connector_output, + ) = self._preprocess( + scheduler_output, num_tokens_padded, intermediate_tensors + ) # update global cos, sin update_cos_sin(positions) - - if self.dynamic_eplb: - self.eplb_updator.take_update_info_from_eplb_process() - + # Set cudagraph mode to none if calc_kv_scales is true. + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. + if self.calculate_kv_scales: # type: ignore[has-type] + cudagraph_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass + self.calculate_kv_scales = False # type: ignore[has-type] # prevent debugger is None if self.debugger is not None: dbg_cfg = getattr(self.debugger, "config", None) @@ -1393,14 +1251,6 @@ class NPUModelRunner(GPUModelRunner): self.debugger.start(model=self.model) else: self.debugger.start() - - uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( - scheduler_output.total_num_scheduled_tokens - == self.input_batch.num_reqs * max_query_len) - has_lora = len(self.input_batch.lora_id_to_lora_request) > 0 - aclgraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora) - if self.ascend_config.enable_async_exponential: self.sampler.do_async_exponential( b_s=logits_indices.shape[0], @@ -1420,78 +1270,89 @@ class NPUModelRunner(GPUModelRunner): set_ascend_forward_context( attn_metadata, self.vllm_config, - num_tokens=num_input_tokens, + num_tokens=num_tokens_padded, num_tokens_across_dp=num_tokens_across_dp, - aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor, + aclgraph_runtime_mode=cudagraph_mode, + batch_descriptor=batch_desc, num_actual_tokens=scheduler_output. 
total_num_scheduled_tokens, model_instance=self.model, skip_compiled=has_encoder_input), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): - hidden_states = self._generate_process_reqs_hidden_states( - num_input_tokens, input_ids, positions, - intermediate_tensors, inputs_embeds, model_kwargs) - + hidden_states = self._model_forward( + num_tokens_padded, input_ids, positions, + intermediate_tensors, inputs_embeds, **model_kwargs) + with (ProfileExecuteDuration().capture_async("post process")): + if self.pcp_size > 1: + # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. + hidden_states = self.pcp_manager.get_restore_hidden_states( + hidden_states + ) aux_hidden_states = None if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = hidden_states - with ProfileExecuteDuration().capture_async("post process"): - # Broadcast PP output for external_launcher (torchrun) - # to make sure we are synced across pp ranks - # TODO: Support overlapping mirco-batches - # https://github.com/vllm-project/vllm/issues/18019 - broadcast_pp_output = \ - self.parallel_config.distributed_executor_backend \ - == "external_launcher" and len(get_pp_group().ranks) > 0 - if not get_pp_group().is_last_rank: - # For mid-pipeline stages, return the hidden states. - if not broadcast_pp_output: + if not self.broadcast_pp_output: + # Common case. + if not get_pp_group().is_last_rank: + # Return the intermediate tensors. + assert isinstance(hidden_states, IntermediateTensors) hidden_states.kv_connector_output = kv_connector_output self.kv_connector_output = kv_connector_output if self.debugger is not None: self.debugger.stop() self.debugger.step() return hidden_states - assert isinstance(hidden_states, IntermediateTensors) - get_pp_group().send_tensor_dict( - hidden_states.tensors, all_gather_group=get_tp_group()) - logits = None - else: - if self.input_batch.pooling_params: - pool_output = self._pool( - hidden_states, - scheduler_output.total_num_scheduled_tokens, - num_scheduled_tokens_np, kv_connector_output) + if self.is_pooling_model: + # Return the pooling output. + output = self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np, kv_connector_output + ) + output.kv_connector_output = kv_connector_output if self.debugger is not None: self.debugger.stop() self.debugger.step() - return pool_output + return output + sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) - if broadcast_pp_output: - model_output_broadcast_data = { - "logits": logits.contiguous(), - } if logits is not None else {} - model_output_broadcast_data = get_pp_group( - ).broadcast_tensor_dict(model_output_broadcast_data, - src=len(get_pp_group().ranks) - 1) - assert model_output_broadcast_data is not None - logits = model_output_broadcast_data["logits"] + else: + # Rare case. 
+ assert not self.is_pooling_model + + if not get_pp_group().is_last_rank: + sample_hidden_states = hidden_states[logits_indices] + get_pp_group().send_tensor_dict( + hidden_states.tensors, all_gather_group=get_tp_group()) + logits = None + else: + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + + model_output_broadcast_data: dict[str, Any] = {} + if logits is not None: + model_output_broadcast_data["logits"] = logits.contiguous() + broadcasted = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert broadcasted is not None + logits = broadcasted["logits"] + # Apply structured output bitmasks if present self.execute_model_state = ExecuteModelState( scheduler_output, logits, spec_decode_metadata, + spec_decode_common_attn_metadata, hidden_states, sample_hidden_states, aux_hidden_states, attn_metadata, positions, - ec_connector_output + ec_connector_output, ) self.kv_connector_output = kv_connector_output return None @@ -1521,12 +1382,13 @@ class NPUModelRunner(GPUModelRunner): scheduler_output, logits, spec_decode_metadata, + spec_decode_common_attn_metadata, hidden_states, sample_hidden_states, aux_hidden_states, attn_metadata, positions, - ec_connector_output + ec_connector_output, ) = self.execute_model_state # Clear ephemeral state. self.execute_model_state = None @@ -1545,12 +1407,13 @@ class NPUModelRunner(GPUModelRunner): sampler_output = self._sample(logits, spec_decode_metadata) def propose_draft_token_ids(sampled_token_ids): - assert self.spec_decode_common_attn_metadata is not None + assert spec_decode_common_attn_metadata is not None self._draft_token_ids = self.propose_draft_token_ids( sampled_token_ids, self.input_batch.sampling_metadata, scheduler_output, spec_decode_metadata, + spec_decode_common_attn_metadata, positions, scheduler_output.total_num_scheduled_tokens, hidden_states, @@ -1592,16 +1455,17 @@ class NPUModelRunner(GPUModelRunner): if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() - model_runner_output = ModelRunnerOutput( req_ids=req_ids_output_copy, req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, - ec_connector_output=ec_connector_output, kv_connector_output=kv_connector_output, pooler_output=[], + ec_connector_output=ec_connector_output + if self.supports_mm_inputs + else None, ) durations = ProfileExecuteDuration().pop_captured_sync() @@ -1615,17 +1479,12 @@ class NPUModelRunner(GPUModelRunner): " ".join(dr_str)) if self.dynamic_eplb: self.eplb_updator.forward_end() - if not self.use_async_scheduling: - if self.debugger is not None: - assert self.debugger is not None - self.debugger.stop() - self.debugger.step() - return model_runner_output if self.debugger is not None: - assert self.debugger is not None self.debugger.stop() self.debugger.step() + if not self.use_async_scheduling: + return model_runner_output return AsyncGPUModelRunnerOutput( model_runner_output=model_runner_output, sampled_token_ids=sampler_output.sampled_token_ids, @@ -1786,146 +1645,432 @@ class NPUModelRunner(GPUModelRunner): invalid_req_indices, ) - def _build_dummy_attn_metadata( + # all-gather one hidden-states in sp scene + @staticmethod + def _all_gather_hidden_states(hidden_states): + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) + pad_size = get_forward_context().pad_size + if pad_size > 0: + 
hidden_states = hidden_states[:-pad_size, :] + + return hidden_states + + # all-gather a list of hidden-states in sp scene + @staticmethod + def _all_gather_hidden_states_list(hidden_states_list): + return [ + NPUModelRunner._all_gather_hidden_states(hidden_states) + for hidden_states in hidden_states_list + ] + + # all-gather hidden-states in last layer with aux-hidden-states in sp scene + @staticmethod + def _all_gather_hidden_states_and_aux(hidden_states): + if isinstance(hidden_states, tuple): + return (NPUModelRunner._all_gather_hidden_states(hidden_states[0]), + NPUModelRunner._all_gather_hidden_states_list( + hidden_states[1])) + return NPUModelRunner._all_gather_hidden_states(hidden_states) + + def _model_forward( self, - with_prefill: bool, - num_reqs: int, - num_tokens: int, - max_query_len: int, - num_scheduled_tokens: np.ndarray, - aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, - force_attention: bool = False, - is_graph_capturing: bool = False, - ) -> Optional[dict[str, Any]]: - attn_metadata: Optional[dict[str, Any]] = None - - if force_attention or aclgraph_runtime_mode == CUDAGraphMode.FULL: - assert with_prefill is False, \ - "Full decode graph only supports uniform batch now." - - attn_metadata = {} - - # The reason why we use a fixed seq_len rather than max_query_len is that - # _npu_paged_attention_get_workspace only returns max workspace with specific - # seq_lens. We use this seq_len only when capturing graph, and still use max_query_len - # in inference. This will be removed once npu_fused_infer_attention_score - # outperforms _npu_paged_attention on all cases. - seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) else max_query_len - self.seq_lens.np[:num_reqs] = seq_lens - self.seq_lens.np[num_reqs:] = 0 - self.seq_lens.copy_to_gpu() - - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) - - self.query_start_loc.cpu[1:num_reqs + - 1] = torch.Tensor(cu_num_tokens) - self.query_lens = torch.from_numpy(num_scheduled_tokens) - - num_computed_tokens_cpu = ( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) - - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - block_table_tensor = self.input_batch.block_table[ - kv_cache_group_id].get_device_tensor() - slot_mapping = self.input_batch.block_table[ - kv_cache_group_id].slot_mapping - long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata( - num_tokens, self.query_lens, self.input_batch, - num_scheduled_tokens) - if long_seq_metadata is not None: - pcp_world_size = get_pcp_group().world_size - dcp_world_size = get_dcp_group().world_size - num_computed_tokens_of_pcp_dcp = [[ - [0] * dcp_world_size for _ in range(pcp_world_size) - ] for _ in range(num_tokens)] - long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp - - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + - 1], - seq_lens_cpu=self.seq_lens.cpu, - seq_lens=self.seq_lens.gpu[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - num_input_tokens=num_tokens, - actual_seq_lengths_q=self.actual_seq_lengths_q, - block_table_tensor=block_table_tensor[:num_reqs], - slot_mapping=slot_mapping.gpu, - num_computed_tokens_cpu=num_computed_tokens_cpu, - positions=self.positions.gpu, - 
attn_state=self.attn_state, - max_query_len=max_query_len, - decode_token_per_req=self.decode_token_per_req, - prefill_context_parallel_metadata=long_seq_metadata, - max_seq_len=0) - if self.pcp_size * self.dcp_size > 1: - common_attn_metadata.block_table_tensor = \ - block_table_tensor[:num_reqs * self.decode_threshold] - attn_state = AscendAttentionState.DecodeOnly - if self.speculative_config and \ - self.speculative_config.method == "mtp": - # `AscendAttentionState.SpecDecoding` is only designed for mla - if self.vllm_config.model_config.use_mla: - attn_state = AscendAttentionState.SpecDecoding - else: - attn_state = AscendAttentionState.ChunkedPrefill - - common_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + - 1], - _seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - seq_lens=self.seq_lens.gpu[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - block_table_tensor=block_table_tensor[:num_reqs], - slot_mapping=slot_mapping.gpu, - _num_computed_tokens_cpu=num_computed_tokens_cpu, - max_query_len=max_query_len, - max_seq_len=seq_lens) - - for attn_group in self.attn_groups[kv_cache_group_id]: - builder = attn_group.get_metadata_builder() - if isinstance(builder, GDNAttentionMetadataBuilder): - attn_metadata_gdn_attention = builder.build_for_cudagraph_capture( - common_metadata) - else: - attn_metadata_full_attention = builder.build_for_graph_capture( - common_attn_metadata, attn_state) - for layer_name in kv_cache_group_spec.layer_names: - if "linear_attn" in layer_name: - attn_metadata[ - layer_name] = attn_metadata_gdn_attention - else: - attn_metadata[ - layer_name] = attn_metadata_full_attention - - return attn_metadata - - def _generate_dummy_run_hidden_states(self, input_ids, positions, - num_tokens, intermediate_tensors, - inputs_embeds): - hidden_states = self.model(input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + num_tokens_padded: int, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any],): + assert self.model is not None + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs) forward_context = get_forward_context() assert forward_context is not None if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ not forward_context.capturing and not self.use_sparse: + assert positions is not None update_full_graph_params(self.attn_backend, self.update_stream, forward_context, - num_tokens, self.vllm_config, + num_tokens_padded, self.vllm_config, self.speculative_config, positions.shape[0]) - - if self.use_aux_hidden_state_outputs: - hidden_states, _ = hidden_states - else: - hidden_states = hidden_states + if get_forward_context().sp_enabled and not isinstance( + hidden_states, IntermediateTensors): + hidden_states = self._all_gather_hidden_states_and_aux( + hidden_states) return hidden_states + def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if enable_sp(): + return round_up(num_scheduled_tokens, tp_size) + return 
num_scheduled_tokens + + def _sync_batch_across_dp( + self, + num_tokens_padded: int | None = None, + cudagraph_mode: int = 0, + ) -> tuple[bool, torch.Tensor | None, int]: + """ + Coordinates amongst all DP ranks to determine if and how the full batch + should be split into microbatches. + + Args: + num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs, + TP, etc) + cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL) + + Returns: tuple[ + ubatch_slices: if this is set then all DP ranks have agreed to + microbatch + num_tokens_after_padding: A tensor containing the total number of + tokens per-microbatch for each DP rank including padding. Will be + padded up to the max value across all DP ranks when allow_dp_padding + is True. + synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks) + ] + + """ + + # TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in + # our case, we still need to sync the other two flags as well. So we need to + # include them in the all_reduce operation, and more over, we CANNOT skip it + # even if we are running in eager mode, which harms performance. + # FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here + # immediately once the other two flags are no longer needed. + + if self.dp_size == 1: + return False, None, cudagraph_mode + + if self._skip_all_reduce_across_dp_group(): + num_tokens_after_padding = torch.tensor([num_tokens_padded] * + self.dp_size, + device="cpu", + dtype=torch.int32) + return False, num_tokens_after_padding, cudagraph_mode + + tensor = torch.zeros(2, self.dp_size, device="cpu", dtype=torch.int32) + tensor[0][self.dp_rank] = num_tokens_padded + tensor[1][self.dp_rank] = cudagraph_mode + dist.all_reduce(tensor, group=get_dp_group().cpu_group) + + num_tokens_across_dp = tensor[0, :] + max_num_tokens = int(num_tokens_across_dp.max().item()) + num_tokens_after_padding = torch.tensor( + [max_num_tokens] * len(num_tokens_across_dp), + device="cpu", + dtype=torch.int32, + ) + # Synchronize cudagraph_mode across ranks (take min) + synced_cudagraph_mode = _post_process_cudagraph_mode(tensor) + return False, num_tokens_after_padding, synced_cudagraph_mode + + def _determine_batch_execution_and_padding( + self, + num_tokens: int, + num_reqs: int, + num_scheduled_tokens_np: np.ndarray, + max_num_scheduled_tokens: int, + use_cascade_attn: bool, + allow_microbatching: bool = False, + force_eager: bool = False, + # For cudagraph capture TODO(lucas): Refactor how we capture cudagraphs (will + # be improved in model runner v2) + force_uniform_decode: bool | None = None, + force_has_lora: bool | None = None, + num_encoder_reqs: int = 0, + ) -> tuple[CUDAGraphMode, BatchDescriptor, bool, + torch.Tensor | None, CUDAGraphStat | None]: + + num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) + uniform_decode = ( + ((max_num_scheduled_tokens == self.uniform_decode_query_len) and + (num_tokens == max_num_scheduled_tokens * num_reqs)) + if force_uniform_decode is None else force_uniform_decode) + # Encoder-decoder models only support CG for decoder_step > 0 (no enc_output + # is present). Also, chunked-prefill is disabled, so batch are uniform. 
+ has_encoder_output = (self.model_config.is_encoder_decoder + and num_encoder_reqs > 0) + has_lora = (len(self.input_batch.lora_id_to_lora_request) > 0 + if force_has_lora is None else force_has_lora) + + # ruff: noqa: E731 + dispatch_cudagraph = ( + lambda num_tokens, disable_full: self.cudagraph_dispatcher. + dispatch( + num_tokens=num_tokens, + has_lora=has_lora, + uniform_decode=uniform_decode, + disable_full=disable_full, + ) if not force_eager else + (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))) + cudagraph_mode, batch_descriptor = dispatch_cudagraph( + num_tokens_padded, use_cascade_attn or has_encoder_output) + num_tokens_padded = batch_descriptor.num_tokens + if enable_sp(self.vllm_config): + assert (batch_descriptor.num_tokens % + self.vllm_config.parallel_config.tensor_parallel_size == 0 + ), ("Sequence parallelism requires num_tokens to be " + "a multiple of tensor parallel size") + # Extra coordination when running data-parallel since we need to coordinate + # across ranks + should_ubatch, num_tokens_across_dp = False, None + if self.vllm_config.parallel_config.data_parallel_size > 1: + _, num_tokens_across_dp, synced_cudagraph_mode = self._sync_batch_across_dp(num_tokens_padded=num_tokens_padded, + cudagraph_mode=cudagraph_mode.value, + ) + + # Extract DP padding if there is any + if num_tokens_across_dp is not None: + dp_rank = self.parallel_config.data_parallel_rank + num_tokens_padded = int(num_tokens_across_dp[dp_rank].item()) + # Re-dispatch with DP padding + cudagraph_mode, batch_descriptor = dispatch_cudagraph( + num_tokens_padded, + disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value,) + # Assert to make sure the agreed upon token count is correct otherwise + # num_tokens_across_dp will no-longer be valid + assert batch_descriptor.num_tokens == num_tokens_padded + cudagraph_stats = None + if self.vllm_config.observability_config.cudagraph_metrics: + cudagraph_stats = CUDAGraphStat( + num_unpadded_tokens=num_tokens, + num_padded_tokens=batch_descriptor.num_tokens, + num_paddings=batch_descriptor.num_tokens - num_tokens, + runtime_mode=str(cudagraph_mode), + ) + + return ( + cudagraph_mode, + batch_descriptor, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) + + def _build_attention_metadata( + self, + num_tokens: int, + num_reqs: int, + max_query_len: int, + num_tokens_padded: int | None = None, + num_reqs_padded: int | None = None, + ubatch_slices: UBatchSlices | None = None, + logits_indices: torch.Tensor | None = None, + use_spec_decode: bool = False, + for_cudagraph_capture: bool = False, + num_scheduled_tokens: dict[str, int] | None = None, + num_scheduled_tokens_np: np.ndarray | None = None, + cascade_attn_prefix_lens: list[list[int]] | None = None, + ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: + """ + :return: tuple[attn_metadata, spec_decode_common_attn_metadata] + """ + # Attention metadata is not needed for attention free models + if len(self.kv_cache_config.kv_cache_groups) == 0: + return {}, None + num_tokens_padded = num_tokens_padded or num_tokens + num_reqs_padded = num_reqs_padded or num_reqs + attn_metadata: PerLayerAttnMetadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] + if for_cudagraph_capture: + # For some attention backends (e.g. FA) with sliding window models we need + # to make sure the backend see a max_seq_len that is larger to the sliding + # window size when capturing to make sure the correct kernel is selected. 
+ max_seq_len = self.max_model_len + else: + max_seq_len = self.seq_lens.np[:num_reqs].max().item() + if use_spec_decode and self.need_accepted_tokens: + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + + kv_cache_groups = self.kv_cache_config.kv_cache_groups + + def _get_pcp_metadata(num_tokens): + if not self.use_cp: + return None + return self.pcp_manager.generate_pcp_metadata(num_tokens, self.query_lens, self.input_batch, num_scheduled_tokens_np) + + def _get_block_table_and_slot_mapping(kv_cache_gid: int): + assert num_reqs_padded is not None and num_tokens_padded is not None + kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec + maybe_pcp_full_tokens = ( + num_tokens_padded if self.pcp_size == 1 else + num_tokens * self.pcp_size - + sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs])) + if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): + blk_table_tensor = torch.zeros( + (num_reqs_padded, 1), + dtype=torch.int32, + device=self.device, + ) + slot_mapping = torch.zeros( + (num_tokens_padded,), + dtype=torch.int64, + device=self.device, + ) + else: + blk_table = self.input_batch.block_table[kv_cache_gid] + slot_mapping = blk_table.slot_mapping.gpu[:maybe_pcp_full_tokens] + maybe_num_reqs_padded = num_reqs_padded * self.decode_token_per_req if self.use_cp else num_reqs_padded + blk_table_tensor = blk_table.get_device_tensor()[:maybe_num_reqs_padded] + + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID + if self.pcp_size == 1: + slot_mapping[num_tokens:num_tokens_padded].fill_(-1) + blk_table_tensor[num_reqs:num_reqs_padded].fill_(0) + if self.pcp_size > 1: + slot_mapping = self.pcp_manager.get_padded_slot_mapping( + num_tokens, + num_tokens_padded, + slot_mapping, + ) + return blk_table_tensor, slot_mapping + + long_seq_metdadata = _get_pcp_metadata(num_tokens) + block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0) + + cm_base = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], + query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], + seq_lens=self.seq_lens.gpu[:num_reqs_padded], + # TODO + seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], + # TODO + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs_padded + ], + num_reqs=num_reqs_padded, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + block_table_tensor=block_table_gid_0, + slot_mapping=slot_mapping_gid_0, + causal=True, + num_input_tokens=num_tokens_padded, + actual_seq_lengths_q=self.actual_seq_lengths_q, + positions=self.positions.gpu, + attn_state=self.attn_state, + decode_token_per_req=self.decode_token_per_req, + prefill_context_parallel_metadata=long_seq_metdadata, + ) + + if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill: + cm_base.num_logits_indices = logits_indices.size(0) + cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill( + logits_indices + ) + + def _build_attn_group_metadata( + kv_cache_gid: int, + attn_gid: int, + common_attn_metadata: CommonAttentionMetadata, + ubid: int | None = None, + ) -> None: + attn_group = self.attn_groups[kv_cache_gid][attn_gid] + builder = attn_group.get_metadata_builder(ubid or 0) + cascade_attn_prefix_len = ( + 
+                cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
+                if cascade_attn_prefix_lens
+                else 0
+            )
+
+            extra_attn_metadata_args = {}
+            if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
+                assert ubid is None, "UBatching not supported with GDN yet"
+                patch_torch_npu_argsort()
+                extra_attn_metadata_args = dict(
+                    num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded],
+                    num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
+                        :num_reqs_padded
+                    ],
+                )
+
+            if for_cudagraph_capture:
+                attn_metadata_i = builder.build_for_cudagraph_capture(
+                    common_attn_metadata
+                )
+            else:
+                attn_metadata_i = builder.build(
+                    common_prefix_len=cascade_attn_prefix_len,
+                    common_attn_metadata=common_attn_metadata,
+                    **extra_attn_metadata_args,
+                )
+
+            if ubid is None:
+                assert isinstance(attn_metadata, dict)
+                attn_metadata_dict = attn_metadata
+            else:
+                assert isinstance(attn_metadata, list)
+                attn_metadata_dict = attn_metadata[ubid]
+
+            for layer_name in attn_group.layer_names:
+                attn_metadata_dict[layer_name] = attn_metadata_i
+
+        # Prepare the attention metadata for each KV cache group and make
+        # layers in the same group share the same metadata.
+        spec_decode_common_attn_metadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            cm = copy(cm_base)  # shallow copy
+            # Basically only the encoder seq_lens, block_table and slot_mapping
+            # change for each kv_cache_group.
+            cm.encoder_seq_lens, cm.encoder_seq_lens_cpu = self._get_encoder_seq_lens(
+                num_scheduled_tokens or {},
+                kv_cache_group.kv_cache_spec,
+                num_reqs_padded,
+            )
+            if kv_cache_gid > 0:
+                cm.block_table_tensor, cm.slot_mapping = (
+                    _get_block_table_and_slot_mapping(kv_cache_gid)
+                )
+            if self.speculative_config and spec_decode_common_attn_metadata is None:
+                if isinstance(self.drafter, EagleProposer):
+                    if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
+                        spec_decode_common_attn_metadata = cm
+                else:
+                    spec_decode_common_attn_metadata = cm
+
+            for attn_gid in range(len(self.attn_groups[kv_cache_gid])):
+                _build_attn_group_metadata(kv_cache_gid, attn_gid, cm)
+        if self.is_mm_prefix_lm:
+            req_doc_ranges = {}
+            for req_id in self.input_batch.req_ids:
+                image_doc_ranges = []
+                req_state = self.requests[req_id]
+                for mm_feature in req_state.mm_features:
+                    pos_info = mm_feature.mm_position
+                    img_doc_range = pos_info.extract_embeds_range()
+                    image_doc_ranges.extend(img_doc_range)
+                req_idx = self.input_batch.req_id_to_index[req_id]
+                req_doc_ranges[req_idx] = image_doc_ranges
+
+            if isinstance(attn_metadata, list):
+                for ub_metadata in attn_metadata:
+                    for _metadata in ub_metadata.values():
+                        _metadata.mm_prefix_range = req_doc_ranges  # type: ignore[attr-defined]
+            else:
+                for _metadata in attn_metadata.values():
+                    _metadata.mm_prefix_range = req_doc_ranges  # type: ignore[attr-defined]
+
+        if spec_decode_common_attn_metadata is not None and (
+            num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
+        ):
+            # Currently the drafter still only uses piecewise cudagraphs (and
+            # modifies the attention metadata in place), and therefore does
+            # not want to use padded attention metadata.
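+            # The unpadded() call below restores the actual (unpadded) token
+            # and request counts before the metadata is handed to the drafter.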
+ spec_decode_common_attn_metadata = ( + spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) + ) + return attn_metadata, spec_decode_common_attn_metadata + @torch.inference_mode() def _dummy_run( self, @@ -1935,37 +2080,17 @@ class NPUModelRunner(GPUModelRunner): force_attention: bool = False, uniform_decode: bool = False, is_profile: bool = False, + create_mixed_batch: bool = False, allow_microbatching: bool = True, skip_eplb: bool = False, remove_lora: bool = True, activate_lora: bool = False, is_graph_capturing: bool = False, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: # only support eager mode and piecewise graph now - assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in { - CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL - } - # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs. - # If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size. - if self.use_aclgraph and enable_sp(self.vllm_config): - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - num_tokens = math.ceil(num_tokens / tp_size) * tp_size - - # Force dummy run on prefill stage when this node is deemed as kv producer. - if self.is_kv_producer and not self.is_kv_consumer: - with_prefill = True - - has_lora = True if self.lora_config and self.compilation_config.cudagraph_specialize_lora else False - _ag_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora) - - # Padding for DP - (num_tokens, num_tokens_across_dp, - with_prefill) = self._sync_metadata_across_dp( - batch_descriptor.num_tokens, with_prefill) - + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes() # If cudagraph_mode.decode_mode() == FULL and - # cudagraph_mode.seperate_routine(). This means that we are using + # cudagraph_mode.separate_routine(). This means that we are using # different graphs and/or modes for mixed prefill-decode batches vs. # uniform decode batches. A uniform decode batch means that all # requests have identical query length, except a potential virtual @@ -1977,79 +2102,112 @@ class NPUModelRunner(GPUModelRunner): # When setting max_query_len = 1, we switch to and capture the optimized # routine of FA2 for pure decode, i.e., Flashdecode + an optimization # for GQA/MQA. - max_query_len = self.uniform_decode_query_len if uniform_decode else \ - num_tokens - + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. 
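+        # e.g. uniform_decode with num_tokens=10 and max_query_len=4 yields
+        # num_reqs=3 and num_scheduled_tokens_list=[4, 4, 2].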
assert num_tokens <= self.scheduler_config.max_num_batched_tokens - max_num_reqs = self.max_num_reqs - if uniform_decode: - num_reqs = cdiv(num_tokens, max_query_len) + max_num_reqs = self.scheduler_config.max_num_seqs + if create_mixed_batch: + raise NotImplementedError("create_mixed_batch is used for warmup deepgemm, vllm-ascend does not need it") + elif uniform_decode: + assert not create_mixed_batch + num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len)) num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: num_scheduled_tokens_list[-1] = num_tokens % max_query_len else: - if with_prefill: - num_reqs = num_tokens - else: - num_reqs = (num_tokens + self.decode_token_per_req - - 1) // self.decode_token_per_req - num_reqs = min(num_reqs, max_num_reqs) + num_reqs = min(num_tokens, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - num_scheduled_tokens = np.array(num_scheduled_tokens_list, - dtype=np.int32) + + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + self.query_lens = torch.from_numpy(num_scheduled_tokens) + num_tokens_unpadded = int(num_scheduled_tokens.sum()) num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) - - if not is_profile and self.dynamic_eplb: - self.eplb_updator.forward_before() - - if num_tokens != batch_descriptor.num_tokens: - _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( - num_tokens=num_tokens, - uniform_decode=uniform_decode, - has_lora=has_lora) - - num_tokens_padded = batch_descriptor.num_tokens - num_reqs_padded = (batch_descriptor.num_reqs if - batch_descriptor.num_reqs is not None else num_reqs) + _cudagraph_mode, batch_desc, _, num_tokens_across_dp, _ = ( + self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens, + max_num_scheduled_tokens=max_query_len, + use_cascade_attn=False, + allow_microbatching=allow_microbatching, + force_eager=is_profile + or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + # `force_uniform_decode` is used for cudagraph capture; because for + # capturing mixed prefill-decode batches, we sometimes use + # num_tokens == num_reqs which looks like a uniform decode batch to the + # dispatcher; but we actually want to capture a piecewise cudagraph + force_uniform_decode=uniform_decode, + # `force_has_lora` is used for cudagraph capture; because LoRA is + # activated later in the context manager, but we need to know the + # LoRA state when determining the batch descriptor for capture + force_has_lora=activate_lora, + ) + ) + if cudagraph_runtime_mode is None: + cudagraph_runtime_mode = _cudagraph_mode + else: + assert cudagraph_runtime_mode == _cudagraph_mode, ( + f"Cudagraph runtime mode mismatch in dummy_run. " + f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}." 
+ ) + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) if num_tokens_across_dp is not None and num_tokens_padded != num_tokens: # pad is needed if the pad of `num_tokens` is triggered inside CudagraphDispatcher num_tokens_across_dp[:] = num_tokens_padded num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded) + # vllm-ascend does not support ubatch now + ubatch_slices, ubatch_slices_padded = None, None + attn_metadata: PerLayerAttnMetadata | None = None + # If force_attention is True, we always capture attention. Otherwise, + # it only happens for cudagraph_runtime_mode=FULL. + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + if create_mixed_batch: + raise NotImplementedError("create_mixed_batch is used for warmup deepgemm, vllm-ascend does not need it") + self.attn_state = AscendAttentionState.DecodeOnly + if self.speculative_config and \ + self.speculative_config.method == "mtp": + # `AscendAttentionState.SpecDecoding` is only designed for mla + if self.vllm_config.model_config.use_mla: + self.attn_state = AscendAttentionState.SpecDecoding + else: + self.attn_state = AscendAttentionState.ChunkedPrefill + # The reason why we use a fixed seq_len rather than max_query_len is that + # _npu_paged_attention_get_workspace only returns max workspace with specific + # seq_lens. We use this seq_len only when capturing graph, and still use max_query_len + # in inference. This will be removed once npu_fused_infer_attention_score + # outperforms _npu_paged_attention on all cases. + seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) else max_query_len # type: ignore[assignment] + self.seq_lens.np[:num_reqs_padded] = seq_lens + self.seq_lens.np[num_reqs_padded:] = 0 + self.seq_lens.copy_to_gpu() + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs_padded + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() + pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL + attn_metadata, _ = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded, + num_reqs=num_reqs_padded, + max_query_len=max_query_len, + ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, + for_cudagraph_capture=is_graph_capturing, + num_scheduled_tokens_np=num_scheduled_tokens, + ) - # filter out the valid batch descriptor - if cudagraph_runtime_mode is not None: - # we allow forcing NONE when the dispatcher disagrees to support - # warm ups for aclgraph capture - if cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode != _ag_mode: - raise ValueError( - f"Aclgraph runtime mode mismatch at dummy_run. " - f"Expected {_ag_mode}, but got {cudagraph_runtime_mode}.") - else: - cudagraph_runtime_mode = _ag_mode - - # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup - # and not supported in ASCEND now. We could remove it in the future. 
- attn_metadata = self._build_dummy_attn_metadata( - False, - num_reqs=num_reqs_padded, - num_tokens=num_tokens_padded, - max_query_len=max_query_len, - aclgraph_runtime_mode=cudagraph_runtime_mode, - force_attention=force_attention, - is_graph_capturing=is_graph_capturing, - num_scheduled_tokens=num_scheduled_tokens, - ) - - with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens, - num_sampled_tokens): + with self.maybe_dummy_run_with_lora( + self.lora_config, + num_scheduled_tokens, + num_sampled_tokens, + ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens if self.is_multimodal_model and not self.model_config.is_encoder_decoder: @@ -2122,14 +2280,18 @@ class NPUModelRunner(GPUModelRunner): num_tokens=num_tokens_padded, num_tokens_across_dp=num_tokens_across_dp, in_profile_run=is_profile, - num_actual_tokens=0, + num_actual_tokens=num_tokens_padded, aclgraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, + batch_descriptor=batch_desc, model_instance=self.model): - hidden_states = self._generate_dummy_run_hidden_states( - input_ids, positions, num_tokens_padded, + outputs = self._model_forward( + num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds) - dummy_compute_logits(hidden_states) + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs + dummy_compute_logits(hidden_states) if self.drafter: self.drafter.dummy_run( @@ -2138,7 +2300,7 @@ class NPUModelRunner(GPUModelRunner): num_reqs=num_reqs_padded, num_tokens_across_dp=num_tokens_across_dp, aclgraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, + batch_descriptor=batch_desc, dummy_compute_logits=dummy_drafter_compute_logits, in_graph_capturing=not force_attention, is_profile=is_profile) @@ -2866,6 +3028,15 @@ class NPUModelRunner(GPUModelRunner): mm_data[field] = tensor.cpu() +def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int: + """ + Synchronize cudagraph_mode across DP ranks by taking the minimum. + If any rank has NONE (0), all ranks use NONE. + This ensures all ranks send consistent values (all padded or all unpadded). 
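+    For example, if the gathered modes are [PIECEWISE, NONE, FULL], the
+    minimum (NONE) is used by every rank.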
+ """ + return int(tensor[1, :].min().item()) + + @contextmanager def _torch_cuda_wrapper(): diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 873c4f1c..5b221e57 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -75,11 +75,13 @@ class PCPManager: dtype=torch.int32, device=device, ) + self.pcp_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) + self.total_num_sampled_tokens_pcp = 0 self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs, ), device="cpu", dtype=torch.int64) self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() - self.pcp_unpad_mask_cpu_tensor = torch.zeros( + self.pcp_unpad_mask_cpu_tensor = torch.ones( (max_buffer_num_tokens, ), device="cpu", dtype=torch.bool, @@ -292,6 +294,8 @@ class PCPManager: all_positions.argsort()) self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + self.pcp_tokens[:num_reqs] = pcp_tokens[:num_reqs] + self.total_num_sampled_tokens_pcp = pcp_tokens[:num_reqs].sum() return ( pcp_tokens[:num_reqs], positions, @@ -312,17 +316,16 @@ class PCPManager: num_scheduled_tokens * self.pcp_world_size - self.num_pcp_pads_cpu[:num_reqs]) < num_tokens_np - def get_padded_slot_mapping(self, num_tokens: int, + def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int, slot_mapping: torch.Tensor): # After pcp allgather and restore, there are padded tokens in kv, # so we need pad slotmapping for alignment. - pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens * - self. - pcp_world_size] + pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens_padded * self.pcp_world_size] + cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[:num_tokens * self.pcp_world_size] pcp_padded_slot_mapping.fill_(-1) - pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping + pcp_padded_slot_mapping[:num_tokens * self.pcp_world_size][cp_unpad_mask] = slot_mapping return pcp_padded_slot_mapping def get_restore_hidden_states(