diff --git a/pyproject.toml b/pyproject.toml index fca3e0f3..0baf2470 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,9 +54,7 @@ exclude = [ # (7) "vllm_ascend/quantization/**", "vllm_ascend/sample/*.py", - "vllm_ascend/worker/v2/**", "vllm_ascend/worker/block_table.py", - "vllm_ascend/worker/npu_input_batch.py", # (8) "vllm_ascend/ops/__init__.py", "vllm_ascend/ops/activation.py", @@ -65,13 +63,9 @@ exclude = [ "vllm_ascend/ops/mla.py", "vllm_ascend/ops/mm_encoder_attention.py", "vllm_ascend/ops/register_custom_ops.py", - "vllm_ascend/ops/rotary_embedding.py", "vllm_ascend/ops/vocab_parallel_embedding.py", "vllm_ascend/ops/weight_prefetch.py", "vllm_ascend/spec_decode/**", - # (9) - "vllm_ascend/worker/model_runner_v1.py", - "vllm_ascend/worker/pcp_utils.py", # (10) "vllm_ascend/ops/*linear*.py", "vllm_ascend/worker/worker.py", @@ -79,6 +73,9 @@ exclude = [ "vllm_ascend/distributed/utils.py", "vllm_ascend/xlite/*.py", "vllm_ascend/patch/worker/patch_*.py", + "vllm_ascend/worker/v2/**", + "vllm_ascend/worker/npu_input_batch.py", + "vllm_ascend/ops/rotary_embedding.py", # (11) "vllm_ascend/ops/fused_moe/**", ] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index df92ecdd..f7718ba8 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -24,24 +24,19 @@ from contextlib import contextmanager, nullcontext from copy import copy, deepcopy from dataclasses import dataclass from multiprocessing import Manager -from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional, Union, TypeAlias, Tuple +from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias import numpy as np import torch import torch.distributed as dist import torch.nn as nn from vllm.attention.layer import Attention, MLAAttention -from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config) from vllm.compilation.cuda_graph import CUDAGraphStat -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.config import CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) -from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group, - get_pcp_group, get_pp_group, - get_tp_group) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group +from vllm.distributed.parallel_state import get_dcp_group, get_dp_group, get_pcp_group, get_pp_group, get_tp_group from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -56,15 +51,26 @@ from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.selector import get_attn_backend # type: ignore from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import (AttentionSpec, - EncoderOnlyAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - MambaSpec, UniformTypeKVCacheSpecs) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - ECConnectorOutput, LogprobsLists, 
LogprobsTensors, - ModelRunnerOutput, SamplerOutput, - make_empty_encoder_model_runner_output) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + MambaSpec, + UniformTypeKVCacheSpecs, +) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + ECConnectorOutput, + LogprobsLists, + LogprobsTensors, + ModelRunnerOutput, + SamplerOutput, + make_empty_encoder_model_runner_output, +) from vllm.v1.sample.logits_processor import build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler @@ -73,28 +79,29 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import record_function_or_nullcontext -from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput, - GPUModelRunner) -from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput -from vllm.v1.worker.utils import AttentionGroup +from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner from vllm.v1.worker.ubatch_utils import ( UBatchSlices, maybe_create_ubatch_slices, ) +from vllm.v1.worker.utils import AttentionGroup from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_paged_attention + # yapf conflicts with isort for this block # yapf: disable -from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, - set_draft_graph_params, - set_graph_params, - update_full_graph_params) +from vllm_ascend.compilation.acl_graph import ( + ACLGraphWrapper, + set_draft_graph_params, + set_graph_params, + update_full_graph_params, +) + # yapf: enable from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor -from vllm_ascend.eplb.core.eplb_device_transfer_loader import \ - D2DExpertWeightLoader +from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader from vllm_ascend.eplb.core.eplb_worker import EplbProcess from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.eplb.utils import model_register @@ -105,17 +112,28 @@ from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer -from vllm_ascend.utils import (AscendDeviceType, - enable_sp, get_ascend_device_type, - is_drafter_moe_model, is_moe_model, - lmhead_tp_enable, maybe_trans_nz, - set_weight_prefetch_method) +from vllm_ascend.utils import ( + AscendDeviceType, + enable_sp, + get_ascend_device_type, + is_drafter_moe_model, + is_moe_model, + lmhead_tp_enable, + maybe_trans_nz, + set_weight_prefetch_method, +) from vllm_ascend.worker.npu_input_batch import NPUInputBatch from vllm_ascend.worker.pcp_utils import PCPManager from vllm_ascend.ascend_forward_context import ( # isort: skip - MoECommType, get_mc2_tokens_capacity, select_moe_comm_method, - set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity) + MoECommType, + get_mc2_tokens_capacity, + select_moe_comm_method, + set_ascend_forward_context, + set_mc2_mask, + set_mc2_tokens_capacity, +) + if TYPE_CHECKING: import xgrammar 
as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput @@ -158,8 +176,7 @@ def graph_capture(device: torch.device): in order to explicitly distinguish the kernels to capture from other kernels possibly launched on background in the default stream. """ - graph_capture_context = GraphCaptureContext( - torch.npu.Stream(device=device)) + graph_capture_context = GraphCaptureContext(torch.npu.Stream(device=device)) stream = graph_capture_context.stream # we use nullcontext now @@ -197,13 +214,14 @@ class ExecuteModelState(NamedTuple): class NPUModelRunner(GPUModelRunner): - def __init__(self, vllm_config: VllmConfig, device: torch.device): # TODO(qcs): These manual pad and unpad for GPUModelRunner are # used to expand some buffers, which need to be reverted after # the following PR is merged: # https://github.com/vllm-project/vllm/pull/28988 - max_pcp_pad_tokens = vllm_config.parallel_config.prefill_context_parallel_size * 2 * vllm_config.scheduler_config.max_num_seqs + max_pcp_pad_tokens = ( + vllm_config.parallel_config.prefill_context_parallel_size * 2 * vllm_config.scheduler_config.max_num_seqs + ) vllm_config.scheduler_config.max_num_batched_tokens += max_pcp_pad_tokens with _torch_cuda_wrapper(): super().__init__(vllm_config, device) @@ -216,8 +234,7 @@ class NPUModelRunner(GPUModelRunner): self.dcp_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group self.pcp_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group( - ).rank_in_group if self.pcp_size > 1 else 0 + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 except Exception: self.dcp_size = 1 self.dcp_rank = 0 @@ -227,8 +244,7 @@ class NPUModelRunner(GPUModelRunner): self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs max_buffer_num_tokens = self.max_num_tokens if self.pcp_size * self.dcp_size > 1: - max_buffer_num_tokens = (self.max_num_tokens + - self.max_num_reqs * 2 * self.pcp_size) + max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size self.pcp_manager = PCPManager( self.pcp_size, self.pcp_rank, @@ -242,10 +258,8 @@ class NPUModelRunner(GPUModelRunner): self.pin_memory, ) # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this - self.input_ids = self._make_buffer(max_buffer_num_tokens, - dtype=torch.int32) - self.positions = self._make_buffer(max_buffer_num_tokens, - dtype=torch.int64) + self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64) self.sampler = AscendSampler() self.attn_state: AscendAttentionState | None = None @@ -258,10 +272,10 @@ class NPUModelRunner(GPUModelRunner): if dump_cfg is not None: if self.model_config.enforce_eager: from msprobe.pytorch import PrecisionDebugger + self.debugger = PrecisionDebugger(dump_cfg) else: - raise RuntimeError( - "Dumping/debugging only works in eager mode.") + raise RuntimeError("Dumping/debugging only works in eager mode.") # use_hybrid_blocks: if hybrid blocks is used. 
self.use_hybrid_blocks: bool = False self.need_accepted_tokens: bool = False @@ -269,8 +283,7 @@ class NPUModelRunner(GPUModelRunner): self.is_multimodal_model = self.model_config.is_multimodal_model self.block_size = vllm_config.cache_config.block_size # Set up Attention - self.use_sparse = hasattr(self.vllm_config.model_config.hf_text_config, - "index_topk") + self.use_sparse = hasattr(self.vllm_config.model_config.hf_text_config, "index_topk") self.attn_backend = get_attn_backend( 0, self.dtype, @@ -278,8 +291,8 @@ class NPUModelRunner(GPUModelRunner): self.block_size, use_mla=self.model_config.use_mla, use_sparse=self.use_sparse, - use_mm_prefix=self.model_config is not None - and self.model_config.is_mm_prefix_lm) + use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm, + ) self._set_up_drafter() @@ -290,14 +303,10 @@ class NPUModelRunner(GPUModelRunner): self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer - set_cos_and_sin(vllm_config, self.max_num_reqs, - self.uniform_decode_query_len, self.dtype, self.device) - set_mc2_tokens_capacity(vllm_config, self.max_num_reqs, - self.uniform_decode_query_len) + set_cos_and_sin(vllm_config, self.max_num_reqs, self.uniform_decode_query_len, self.dtype, self.device) + set_mc2_tokens_capacity(vllm_config, self.max_num_reqs, self.uniform_decode_query_len) set_mc2_mask(vllm_config, self.device) - self.decode_threshold = 1 + ( - self.speculative_config.num_speculative_tokens - if self.speculative_config else 0) + self.decode_threshold = 1 + (self.speculative_config.num_speculative_tokens if self.speculative_config else 0) self.use_aclgraph = self._use_aclgraph() @@ -308,17 +317,10 @@ class NPUModelRunner(GPUModelRunner): self.policy_type = eplb_config.eplb_policy_type self.eplb_loader = D2DExpertWeightLoader() self.manager = Manager() - self.shared_dict = self.manager.dict({ - "expert_map": None, - "moe_load": None, - "expert_maps": None - }) - self.eplb_process = EplbProcess(shared_dict=self.shared_dict, - policy_type=self.policy_type, - enable_d2d=True) + self.shared_dict = self.manager.dict({"expert_map": None, "moe_load": None, "expert_maps": None}) + self.eplb_process = EplbProcess(shared_dict=self.shared_dict, policy_type=self.policy_type, enable_d2d=True) self.process = self.eplb_process._launch_process() - self.eplb_updator = EplbUpdator(eplb_config, self.eplb_loader, - self.eplb_process, self.process) + self.eplb_updator = EplbUpdator(eplb_config, self.eplb_loader, self.eplb_process, self.process) # Input Batch # NOTE(Chen): Ideally, we should initialize the input batch inside # `initialize_kv_cache` based on the kv cache config. However, as in @@ -330,8 +332,7 @@ class NPUModelRunner(GPUModelRunner): # the block_sizes in the kv cache config. 
self.input_batch = NPUInputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=max(self.model_config.max_model_len, - self.max_encoder_len), + max_model_len=max(self.model_config.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -340,18 +341,19 @@ class NPUModelRunner(GPUModelRunner): kernel_block_sizes=[[self.cache_config.block_size]], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( - self.vllm_config, self.device, self.pin_memory, + self.vllm_config, + self.device, + self.pin_memory, self.is_pooling_model, - self.vllm_config.model_config.logits_processors), + self.vllm_config.model_config.logits_processors, + ), is_pooling_model=self.is_pooling_model, num_speculative_tokens=( - self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0), - cp_kv_cache_interleave_size=self.parallel_config. - cp_kv_cache_interleave_size, + self.vllm_config.speculative_config.num_speculative_tokens if self.vllm_config.speculative_config else 0 + ), + cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, ) - self.num_draft_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) + self.num_draft_tokens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) # here we use int32 self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_num_reqs, 1), @@ -378,9 +380,9 @@ class NPUModelRunner(GPUModelRunner): def _set_up_drafter(self): # Set up speculative decoding. - self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer, - SuffixDecodingProposer, - MedusaProposer]] = None + self.drafter: NgramProposer | EagleProposer | MtpProposer | SuffixDecodingProposer | MedusaProposer | None = ( + None + ) self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1 if self.speculative_config: @@ -391,22 +393,23 @@ class NPUModelRunner(GPUModelRunner): self.drafter = self._get_drafter() if self.speculative_config.method == "eagle3": assert isinstance(self.drafter, EagleProposer) - self.use_aux_hidden_state_outputs = ( - self.drafter.eagle3_use_aux_hidden_state) + self.use_aux_hidden_state_outputs = self.drafter.eagle3_use_aux_hidden_state self.rejection_sampler = RejectionSampler(self.sampler) self.actual_seq_lengths_q = list( - range(self.decode_token_per_req, self.max_num_tokens + 1, - self.decode_token_per_req)) - self.discard_request_indices = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) + range(self.decode_token_per_req, self.max_num_tokens + 1, self.decode_token_per_req) + ) + self.discard_request_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int64) self.num_discarded_requests = 0 def _get_drafter(self): - return get_spec_decode_method(self.speculative_config.method, - self.vllm_config, self.device, self) + return get_spec_decode_method(self.speculative_config.method, self.vllm_config, self.device, self) def _use_aclgraph(self) -> bool: - return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager + return ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and self.compilation_config.mode == CompilationMode.VLLM_COMPILE + and not self.model_config.enforce_eager + ) def _skip_all_reduce_across_dp_group(self, is_draft_model=False) -> bool: """ @@ -419,8 +422,9 @@ class NPUModelRunner(GPUModelRunner): """ # For dense models, since we 
don't actually need dp communication, we simply skip it.
        # This usually happens when the main model is MoE while the eagle draft model is dense.
-        is_context_moe_model = is_drafter_moe_model(self.vllm_config) if is_draft_model \
-            else is_moe_model(self.vllm_config)
+        is_context_moe_model = (
+            is_drafter_moe_model(self.vllm_config) if is_draft_model else is_moe_model(self.vllm_config)
+        )
         if not is_context_moe_model:
             return True
@@ -429,9 +433,7 @@
             return False

         def needs_mc2(num_tokens: int) -> bool:
-            return select_moe_comm_method(num_tokens, self.vllm_config) in {
-                MoECommType.MC2, MoECommType.FUSED_MC2
-            }
+            return select_moe_comm_method(num_tokens, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}

         # Determine whether decode must use MC2. Use max cudagraph capture size
         # if available, otherwise use the maximal uniform decode token count.
@@ -443,21 +445,15 @@
         # For prefill, use the scheduler's max_num_batched_tokens for a single
         # batch.
-        prefill_must_use_mc2 = needs_mc2(
-            self.vllm_config.scheduler_config.max_num_batched_tokens)
+        prefill_must_use_mc2 = needs_mc2(self.vllm_config.scheduler_config.max_num_batched_tokens)

         # Skip all-reduce if decode requires MC2 and either prefill also
         # requires MC2 or recompute-based scheduler is enabled.
-        return decode_must_use_mc2 and (
-            prefill_must_use_mc2
-            or self.ascend_config.recompute_scheduler_enable)
+        return decode_must_use_mc2 and (prefill_must_use_mc2 or self.ascend_config.recompute_scheduler_enable)

     def _sync_metadata_across_dp(
-        self,
-        num_tokens: int,
-        with_prefill: bool = False,
-        is_draft_model: bool = False
-    ) -> tuple[int, Optional[torch.Tensor], bool]:
+        self, num_tokens: int, with_prefill: bool = False, is_draft_model: bool = False
+    ) -> tuple[int, torch.Tensor | None, bool]:
         # TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in
         # our case, we still need to sync the other two flags as well. So we need to
         # include them in the all_reduce operation, and moreover, we CANNOT skip it
@@ -468,22 +464,15 @@
             return num_tokens, None, with_prefill

         if self._skip_all_reduce_across_dp_group(is_draft_model):
-            num_tokens_after_padding = torch.tensor([num_tokens] *
-                                                    self.dp_size,
-                                                    device="cpu",
-                                                    dtype=torch.int32)
+            num_tokens_after_padding = torch.tensor([num_tokens] * self.dp_size, device="cpu", dtype=torch.int32)
             return num_tokens, num_tokens_after_padding, with_prefill

         # Sync num_tokens, with_prefill across dp ranks
-        num_tokens_tensor = torch.tensor([
-            num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
-        ],
-                                         dtype=torch.int32,
-                                         device="cpu")
+        num_tokens_tensor = torch.tensor(
+            [num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)], dtype=torch.int32, device="cpu"
+        )

-        flags_tensor = torch.tensor([int(with_prefill)],
-                                    dtype=torch.int32,
-                                    device="cpu")
+        flags_tensor = torch.tensor([int(with_prefill)], dtype=torch.int32, device="cpu")

         packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
         # use cpu_group to avoid cpu synchronization issues.
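A note for reviewers on the hunk above: the sync packs every rank's token count plus a prefill flag into one int32 CPU tensor so a single all-reduce covers both. A minimal standalone sketch of that pattern, assuming a two-rank gloo (CPU) group; the helper names and port are illustrative, not vllm-ascend APIs:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def sync_metadata(rank: int, world_size: int, num_tokens: int, with_prefill: bool):
    # Each rank fills only its own slot, so a SUM all-reduce recovers every
    # rank's token count; the extra slot ORs the prefill flags (non-zero sum).
    num_tokens_tensor = torch.tensor(
        [num_tokens if i == rank else 0 for i in range(world_size)], dtype=torch.int32
    )
    flags_tensor = torch.tensor([int(with_prefill)], dtype=torch.int32)
    packed = torch.cat([num_tokens_tensor, flags_tensor])
    dist.all_reduce(packed)  # one collective for counts and flags together
    max_tokens_across_dp = int(packed[:world_size].max())
    global_with_prefill = bool(packed[world_size].item())
    return max_tokens_across_dp, global_with_prefill


def _worker(rank: int, world_size: int):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29507", rank=rank, world_size=world_size
    )
    result = sync_metadata(rank, world_size, num_tokens=64 * (rank + 1), with_prefill=rank == 0)
    print(f"rank {rank}: {result}")  # both ranks see (128, True)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```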
@@ -497,10 +486,7 @@ class NPUModelRunner(GPUModelRunner): global_with_prefill = bool(synced_flags[0]) # Create a tensor for num_tokens_after_padding - num_tokens_after_padding = torch.tensor([max_tokens_across_dp] * - self.dp_size, - device="cpu", - dtype=torch.int32) + num_tokens_after_padding = torch.tensor([max_tokens_across_dp] * self.dp_size, device="cpu", dtype=torch.int32) return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill @@ -514,9 +500,7 @@ class NPUModelRunner(GPUModelRunner): self, scheduler_output: "SchedulerOutput", num_scheduled_tokens: np.ndarray, - ) -> tuple[ - torch.Tensor, - SpecDecodeMetadata | None]: + ) -> tuple[torch.Tensor, SpecDecodeMetadata | None]: """ :return: tuple[ logits_indices, spec_decode_metadata, @@ -537,33 +521,28 @@ class NPUModelRunner(GPUModelRunner): if not scheduler_output.scheduled_spec_decode_tokens: num_valid_tokens = num_scheduled_tokens else: - num_valid_tokens = np.array([ - scheduler_output.num_scheduled_tokens[i] - - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) - for i in self.input_batch.req_ids - ], dtype=np.int32) - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, - num_valid_tokens) + num_valid_tokens = np.array( + [ + scheduler_output.num_scheduled_tokens[i] + - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) + for i in self.input_batch.req_ids + ], + dtype=np.int32, + ) + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) self.attn_state = attn_state # type: ignore # Determine if it's a splitfuse batch - with_prefill = attn_state not in [ - AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding - ] + with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding] self.with_prefill = with_prefill # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) - self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens) + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) if self.pcp_size * self.dcp_size > 1: self.pcp_manager.init_batch_info( @@ -584,20 +563,18 @@ class NPUModelRunner(GPUModelRunner): cu_num_tokens, self._draft_token_ids, # type: ignore[has-type] scheduler_output, - self.num_spec_tokens) + self.num_spec_tokens, + ) if self.pcp_size > 1: - num_scheduled_tokens[: - num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp( - num_scheduled_tokens[:num_reqs], - self.arange_np, - ) + num_scheduled_tokens[:num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp( + num_scheduled_tokens[:num_reqs], + self.arange_np, + ) # Re-update after PCP split sequences. 
total_num_scheduled_tokens = sum(num_scheduled_tokens) - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) - cu_num_tokens, _ = self._get_cumsum_and_arange( - num_scheduled_tokens) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) positions_np = self.positions.np[:total_num_scheduled_tokens] np.add( self.input_batch.num_computed_tokens_cpu[req_indices], @@ -610,30 +587,28 @@ class NPUModelRunner(GPUModelRunner): # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices = positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] token_indices_tensor = torch.from_numpy(token_indices) # Prepare input_ids. # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - token_indices_tensor, - out=self.input_ids.cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + token_indices_tensor, + out=self.input_ids.cpu[:total_num_scheduled_tokens], + ) if self.enable_prompt_embeds: is_token_ids = self.input_batch.is_token_ids_tensor.flatten() torch.index_select( - is_token_ids, - 0, - token_indices_tensor, - out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) + is_token_ids, 0, token_indices_tensor, out=self.is_token_ids.cpu[:total_num_scheduled_tokens] + ) # Because we did not pre-allocate a massive prompt_embeds CPU tensor on # the InputBatch, we need to fill in the prompt embeds into the expected # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor. - if self.input_batch.req_prompt_embeds and (self.is_multimodal_model or - self.enable_prompt_embeds): + if self.input_batch.req_prompt_embeds and (self.is_multimodal_model or self.enable_prompt_embeds): output_idx = 0 for req_idx in range(num_reqs): num_sched = num_scheduled_tokens[req_idx] @@ -662,31 +637,28 @@ class NPUModelRunner(GPUModelRunner): actual_num_sched = actual_end - start_pos if actual_num_sched > 0: - self.inputs_embeds.cpu[output_idx:output_idx + - actual_num_sched].copy_( - req_embeds[start_pos:actual_end] - ) + self.inputs_embeds.cpu[output_idx : output_idx + actual_num_sched].copy_( + req_embeds[start_pos:actual_end] + ) output_idx += num_sched self.query_start_loc.np[0] = 0 - self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens # NOTE: Due to the FIA operator limitation, here we pad so that hidden_states.shape[0] # and self.query_start_loc[num_reqs_padded] are equal - self.query_start_loc.np[num_reqs + 1:] = (self.arange_np[1:self.max_num_reqs + 1 - num_reqs] - * self.uniform_decode_query_len + cu_num_tokens[-1]) + self.query_start_loc.np[num_reqs + 1 :] = ( + self.arange_np[1 : self.max_num_reqs + 1 - num_reqs] * self.uniform_decode_query_len + cu_num_tokens[-1] + ) self.query_start_loc.copy_to_gpu() - self.seq_lens.np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) + self.seq_lens.np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens self.seq_lens.copy_to_gpu() self.seq_lens.gpu[num_reqs:].fill_(0) # Copy the tensors to the NPU. 
- self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens, - cu_num_tokens) + self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens, cu_num_tokens) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -709,29 +681,24 @@ class NPUModelRunner(GPUModelRunner): # Record the index of requests that should not be sampled, # so that we could clear the sampled tokens before returning - num_tokens = [ - self.requests[r].num_tokens for r in self.input_batch.req_ids - ] + num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens_np = np.array(num_tokens, dtype=np.int32) base_num_reqs = self.input_batch.num_reqs num_reqs = base_num_reqs if self.pcp_size > 1: # while pcp > 1, we need the original num_scheduled_tokens before split # to calculate discard_requests_mask - tokens_original = [ - scheduler_output.num_scheduled_tokens[i] for i in self.input_batch.req_ids - ] - original_seq_lens_np = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - np.array(tokens_original, dtype=np.int32)) + tokens_original = [scheduler_output.num_scheduled_tokens[i] for i in self.input_batch.req_ids] + original_seq_lens_np = self.input_batch.num_computed_tokens_cpu[:num_reqs] + np.array( + tokens_original, dtype=np.int32 + ) discard_requests_mask = original_seq_lens_np < num_tokens_np else: discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np discard_request_indices = np.nonzero(discard_requests_mask)[0] self.num_discarded_requests = len(discard_request_indices) - self.discard_request_indices.np[:self.num_discarded_requests] = ( - discard_request_indices) + self.discard_request_indices.np[: self.num_discarded_requests] = discard_request_indices self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -745,10 +712,9 @@ class NPUModelRunner(GPUModelRunner): num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) if self.pcp_size * self.dcp_size > 1: logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens) - logits_indices = logits_indices.pin_memory().to( - self.device, non_blocking=True) + logits_indices = logits_indices.pin_memory().to(self.device, non_blocking=True) else: - logits_indices = self.query_start_loc.gpu[1:num_reqs + 1] - 1 + logits_indices = self.query_start_loc.gpu[1 : num_reqs + 1] - 1 else: # Get the number of draft tokens for each request. 
        # Iterate over the dictionary rather than all requests since not all
@@ -763,15 +729,19 @@
            ) in scheduler_output.scheduled_spec_decode_tokens.items():
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)
-                num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
-                    self.input_batch.num_computed_tokens_cpu[req_idx]
-                    >= self.input_batch.num_prompt_tokens[req_idx]) else -1)
+                num_decode_draft_tokens[req_idx] = (
+                    len(draft_token_ids)
+                    if (
+                        self.input_batch.num_computed_tokens_cpu[req_idx] >= self.input_batch.num_prompt_tokens[req_idx]
+                    )
+                    else -1
+                )

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens,
                cu_num_tokens,
-                num_pcp_pads=self.pcp_manager.num_pcp_pads_cpu[:num_reqs]
-                if self.pcp_size > 1 else None)
+                num_pcp_pads=self.pcp_manager.num_pcp_pads_cpu[:num_reqs] if self.pcp_size > 1 else None,
+            )
            logits_indices = spec_decode_metadata.logits_indices
            num_sampled_tokens = num_draft_tokens + 1
@@ -784,35 +754,27 @@

        # Hot-Swap lora model
        if self.lora_config:
-            assert (
-                np.sum(num_sampled_tokens)
-                <= self.vllm_config.scheduler_config.max_num_batched_tokens
-            )
-            self.set_active_loras(
-                self.input_batch, num_scheduled_tokens, num_sampled_tokens
-            )
+            assert np.sum(num_sampled_tokens) <= self.vllm_config.scheduler_config.max_num_batched_tokens
+            self.set_active_loras(self.input_batch, num_scheduled_tokens, num_sampled_tokens)

        if lmhead_tp_enable():
            max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len
-            logits_indices = nn.functional.pad(
-                logits_indices,
-                (0, max_num_reqs_across_dp - logits_indices.shape[0]))
+            logits_indices = nn.functional.pad(logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0]))

        return logits_indices, spec_decode_metadata

-    def _build_attn_state(self, num_reqs, num_scheduled_tokens,
-                          num_valid_tokens):
+    def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens):
        if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
            attn_state = AscendAttentionState.PrefillNoCache
        # We assume it is the decode stage even when prefill occurs but only one token misses the cache.
        elif np.all(num_scheduled_tokens == 1):
            attn_state = AscendAttentionState.DecodeOnly
-            if self.speculative_config and self.speculative_config.method == 'mtp':
+            if self.speculative_config and self.speculative_config.method == "mtp":
                # SpecDecoding now supports seq_len=1 and seq_len=2
                # In the Prefilling Decoding Disaggregation scenario, SpecDecoding needs to support seq_len=1
                attn_state = AscendAttentionState.SpecDecoding
        # Speculative decoding.
        elif np.all(num_valid_tokens == 1):
-            if self.speculative_config and self.speculative_config.method == 'mtp':
+            if self.speculative_config and self.speculative_config.method == "mtp":
                attn_state = AscendAttentionState.SpecDecoding
            else:
                attn_state = AscendAttentionState.ChunkedPrefill
@@ -846,13 +808,11 @@
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        total_num_sampled_tokens = cu_num_sampled_tokens[-1]
        # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-        cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
-                                    num_sampled_tokens)
+        cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
        # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
        # Step 4.
[0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] - logits_indices = np.repeat( - cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] logits_indices += arange @@ -860,13 +820,9 @@ class NPUModelRunner(GPUModelRunner): # update logits_indices after getting draft_token_ids from ori logits_indices if self.pcp_size > 1: cu_num_scheduled_tokens = cu_num_scheduled_tokens * self.pcp_size - num_pcp_pads - logits_indices_pcp = np.repeat( - cu_num_scheduled_tokens - num_sampled_tokens, - num_sampled_tokens) + logits_indices_pcp = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) logits_indices_pcp += arange - logits_indices_pcp = torch.from_numpy( - logits_indices_pcp).pin_memory().to(self.device, - non_blocking=True) + logits_indices_pcp = torch.from_numpy(logits_indices_pcp).pin_memory().to(self.device, non_blocking=True) # Compute the bonus logits indices. bonus_logits_indices = cu_num_sampled_tokens - 1 @@ -876,31 +832,20 @@ class NPUModelRunner(GPUModelRunner): cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) total_num_draft_tokens = cu_num_draft_tokens[-1] # [0, 0, 0, 3, 3, 5] - cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens, - num_draft_tokens) + cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens, num_draft_tokens) # [0, 1, 2, 0, 1, 0] arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets # [0, 0, 0, 5, 5, 9] - target_logits_indices = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + target_logits_indices = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) # [0, 1, 2, 5, 6, 9] target_logits_indices += arange # TODO: Optimize the CPU -> NPU copy. - cu_num_draft_tokens = ( - torch.from_numpy(cu_num_draft_tokens).pin_memory().to( - self.device, non_blocking=True)) - cu_num_sampled_tokens = ( - torch.from_numpy(cu_num_sampled_tokens).pin_memory().to( - self.device, non_blocking=True)) - logits_indices = (torch.from_numpy(logits_indices).pin_memory().to( - self.device, non_blocking=True)) - target_logits_indices = ( - torch.from_numpy(target_logits_indices).pin_memory().to( - self.device, non_blocking=True)) - bonus_logits_indices = torch.from_numpy( - bonus_logits_indices).pin_memory().to(self.device, - non_blocking=True) + cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).pin_memory().to(self.device, non_blocking=True) + cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).pin_memory().to(self.device, non_blocking=True) + logits_indices = torch.from_numpy(logits_indices).pin_memory().to(self.device, non_blocking=True) + target_logits_indices = torch.from_numpy(target_logits_indices).pin_memory().to(self.device, non_blocking=True) + bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(self.device, non_blocking=True) # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] @@ -931,21 +876,27 @@ class NPUModelRunner(GPUModelRunner): hidden_states: torch.Tensor, attn_metadata: list[dict[str, Any]] | dict[str, Any], aux_hidden_states: torch.Tensor = None, - sample_hidden_states: torch.Tensor = None - ) -> Optional[list[list[int]]]: + sample_hidden_states: torch.Tensor = None, + ) -> list[list[int]] | None: if not self.drafter: # Speculative decoding is not enabled. 
draft_token_ids = None
        else:
            if self.speculative_config.method in ("suffix", "ngram"):
                draft_token_ids = self.drafter.generate_token_ids(
-                    valid_sampled_token_ids, sampling_metadata,
-                    scheduler_output, spec_decode_metadata, positions,
-                    num_scheduled_tokens, hidden_states, aux_hidden_states)
+                    valid_sampled_token_ids,
+                    sampling_metadata,
+                    scheduler_output,
+                    spec_decode_metadata,
+                    positions,
+                    num_scheduled_tokens,
+                    hidden_states,
+                    aux_hidden_states,
+                )
            elif isinstance(self.drafter, MedusaProposer):
                draft_token_ids = self.drafter.generate_token_ids(
-                    valid_sampled_token_ids, sampling_metadata,
-                    spec_decode_metadata, sample_hidden_states)
+                    valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states
+                )
            elif self.speculative_config.use_eagle():
                common_attn_metadata = spec_decode_common_attn_metadata
                sampled_token_ids = valid_sampled_token_ids
@@ -954,33 +905,31 @@
                    # When padded-batch is disabled, the sampled_token_ids should be
                    # the cpu-side list[list[int]] of valid sampled tokens for each
                    # request, with invalid requests having empty lists.
-                    assert isinstance(sampled_token_ids, list), \
-                        "sampled_token_ids should be a python list when" \
-                        "padded-batch is disabled."
+                    assert isinstance(sampled_token_ids, list), (
+                        "sampled_token_ids should be a python list when padded-batch is disabled."
+                    )
                    assert self.drafter is not None
                    next_token_ids = self.drafter.prepare_next_token_ids_cpu(
-                        sampled_token_ids, self.requests, self.input_batch,
-                        scheduler_output.num_scheduled_tokens)
+                        sampled_token_ids, self.requests, self.input_batch, scheduler_output.num_scheduled_tokens
+                    )
                else:
                    # When using padded-batch, the sampled_token_ids should be
                    # the gpu tensor of sampled tokens for each request, of shape
                    # (num_reqs, num_spec_tokens + 1) with rejected tokens having
                    # value -1.
-                    assert isinstance(sampled_token_ids, torch.Tensor), \
-                        "sampled_token_ids should be a torch.Tensor when" \
-                        "padded-batch is enabled."
+                    assert isinstance(sampled_token_ids, torch.Tensor), (
+                        "sampled_token_ids should be a torch.Tensor when padded-batch is enabled."
+ ) assert self.drafter is not None - next_token_ids, valid_sampled_tokens_count = \ - self.drafter.prepare_next_token_ids_padded( - common_attn_metadata, - sampled_token_ids, - self.requests, - self.input_batch, - self.discard_request_indices.gpu, - self.num_discarded_requests - ) - self._copy_valid_sampled_token_count( - next_token_ids, valid_sampled_tokens_count) + next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count) req_scheduled_tokens = scheduler_output.num_scheduled_tokens if self.pcp_size * self.dcp_size > 1: @@ -989,8 +938,6 @@ class NPUModelRunner(GPUModelRunner): query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu num_reqs = self.input_batch.num_reqs - ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ - query_start_loc_pcp_full_cpu[:num_reqs] num_prefill_reqs = self.pcp_manager.num_prefill_reqs num_decode_reqs = self.pcp_manager.num_decode_reqs else: @@ -1000,52 +947,43 @@ class NPUModelRunner(GPUModelRunner): if spec_decode_metadata is None: # update pcp related params if self.pcp_size > 1: - token_indices_to_sample = \ - query_start_loc_pcp_full[1:num_reqs + 1] - 1 - target_token_ids = input_ids_pcp_full[: - num_scheduled_tokens] + token_indices_to_sample = query_start_loc_pcp_full[1 : num_reqs + 1] - 1 + target_token_ids = input_ids_pcp_full[:num_scheduled_tokens] target_positions = self._get_positions(num_scheduled_tokens) target_hidden_states = hidden_states else: token_indices_to_sample = None # input_ids can be None for multimodal models. 
-                    target_token_ids = self.input_ids.gpu[:
-                                                          num_scheduled_tokens]
+                    target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
                    target_positions = self._get_positions(num_scheduled_tokens)
                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat([
-                            h[:num_scheduled_tokens]
-                            for h in aux_hidden_states
-                        ],
-                                                         dim=-1)
+                        target_hidden_states = torch.cat(
+                            [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
+                        )
                    else:
-                        target_hidden_states = hidden_states[:
-                                                             num_scheduled_tokens]
+                        target_hidden_states = hidden_states[:num_scheduled_tokens]
            else:
                if self.pcp_size > 1:
                    assert common_attn_metadata is not None
-                    common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
-                        query_start_loc_pcp_full_cpu[:num_reqs + 1]
+                    common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] = query_start_loc_pcp_full_cpu[
+                        : num_reqs + 1
+                    ]
                    assert common_attn_metadata is not None
-                    common_attn_metadata.query_start_loc[:num_reqs + 1] = \
-                        query_start_loc_pcp_full[:num_reqs + 1]
+                    common_attn_metadata.query_start_loc[: num_reqs + 1] = query_start_loc_pcp_full[: num_reqs + 1]
                if self.vllm_config.speculative_config.disable_padded_drafter_batch:
                    # NOTE: Currently, MTP-fullgraph is incompatible with pcp
                    token_indices_to_sample = None
                    assert self.drafter is not None
-                    common_attn_metadata, token_indices =\
-                        self.drafter.prepare_inputs(
-                        common_attn_metadata,
-                        sampled_token_ids,
-                        spec_decode_metadata.num_draft_tokens)
+                    common_attn_metadata, token_indices = self.drafter.prepare_inputs(
+                        common_attn_metadata, sampled_token_ids, spec_decode_metadata.num_draft_tokens
+                    )
                else:
                    assert self.drafter is not None
-                    common_attn_metadata, token_indices, \
-                        token_indices_to_sample =\
-                        self.drafter.prepare_inputs_padded(
-                            common_attn_metadata,
-                            spec_decode_metadata,
-                            valid_sampled_tokens_count)
+                    common_attn_metadata, token_indices, token_indices_to_sample = (
+                        self.drafter.prepare_inputs_padded(
+                            common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
+                        )
+                    )
                if self.pcp_size > 1:
                    target_token_ids = input_ids_pcp_full[token_indices]
                    target_positions = positions
@@ -1054,9 +992,7 @@
                    target_token_ids = self.input_ids.gpu[token_indices]
                    target_positions = self._get_positions(token_indices)
                if self.use_aux_hidden_state_outputs:
-                    target_hidden_states = torch.cat(
-                        [h[token_indices] for h in aux_hidden_states],
-                        dim=-1)
+                    target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
                else:
                    target_hidden_states = hidden_states[token_indices]
            assert self.drafter is not None
@@ -1077,8 +1013,7 @@
            )

        else:
-            raise ValueError("Unknown speculative decoding method: "
-                             f"{self.speculative_config.method}")
+            raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")
        return draft_token_ids

@@ -1086,11 +1021,10 @@
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
-        intermediate_tensors: Optional[IntermediateTensors] = None,
+        intermediate_tensors: IntermediateTensors | None = None,
    ) -> ModelRunnerOutput | IntermediateTensors | None:
        if self.execute_model_state is not None:
-            raise RuntimeError("State error: sample_tokens() must be called "
-                               "after execute_model() returns None.")
+            raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.")
        # self._draft_token_ids is None when `input_fits_in_drafter=False`
        # and there are no draft tokens scheduled, so we need to update the
        # spec_decoding info in scheduler_output with async_scheduling.
        # TODO(Ronald1995): deepcopy is expensive when there is a large
        # number of requests, optimize it later.
        if (
-            self.use_async_scheduling
-            and self.num_spec_tokens
-            and self._draft_token_ids is None  # type: ignore[has-type]
+            self.use_async_scheduling and self.num_spec_tokens and self._draft_token_ids is None  # type: ignore[has-type]
        ):
            scheduler_output = deepcopy(scheduler_output)
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1120,8 +1052,7 @@
        if not num_scheduled_tokens:
            if (
-                self.parallel_config.distributed_executor_backend
-                == "external_launcher"
+                self.parallel_config.distributed_executor_backend == "external_launcher"
                and self.parallel_config.data_parallel_size > 1
            ):
                # this is a corner case when both external launcher
@@ -1134,9 +1065,7 @@
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT
-            return self.kv_connector_no_forward(
-                scheduler_output, self.vllm_config
-            )
+            return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
        if self.cache_config.kv_sharing_fast_prefill:
            assert not self.num_prompt_logprobs, (
                "--kv-sharing-fast-prefill produces incorrect "
@@ -1194,9 +1123,7 @@
            )
        num_tokens_padded = batch_desc.num_tokens
-        num_reqs_padded = (
-            batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
-        )
+        num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
            should_ubatch,
            num_scheduled_tokens_np,
@@ -1210,20 +1137,18 @@
        use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
        ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
-        (attn_metadata, spec_decode_common_attn_metadata) = (
-            self._build_attention_metadata(
-                num_tokens=num_tokens_unpadded,
-                num_tokens_padded=num_tokens_padded,
-                num_reqs=num_reqs,
-                num_reqs_padded=num_reqs_padded,
-                max_query_len=max_num_scheduled_tokens,
-                ubatch_slices=ubatch_slices_attn,
-                logits_indices=logits_indices,
-                use_spec_decode=use_spec_decode,
-                num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
-                num_scheduled_tokens_np=num_scheduled_tokens_np,
-                cascade_attn_prefix_lens=cascade_attn_prefix_lens,
-            )
+        (attn_metadata, spec_decode_common_attn_metadata) = self._build_attention_metadata(
+            num_tokens=num_tokens_unpadded,
+            num_tokens_padded=num_tokens_padded,
+            num_reqs=num_reqs,
+            num_reqs_padded=num_reqs_padded,
+            max_query_len=max_num_scheduled_tokens,
+            ubatch_slices=ubatch_slices_attn,
+            logits_indices=logits_indices,
+            use_spec_decode=use_spec_decode,
+            num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
+            num_scheduled_tokens_np=num_scheduled_tokens_np,
+            cascade_attn_prefix_lens=cascade_attn_prefix_lens,
        )

        (
            intermediate_tensors,
            model_kwargs,
            ec_connector_output,
-        ) = self._preprocess(
-            scheduler_output, num_tokens_padded, intermediate_tensors
-        )
+        ) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors)
        # update global cos, sin
        update_cos_sin(positions)
        # Set cudagraph mode to none if calc_kv_scales is true.
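For context on the `update_cos_sin(positions)` call that closes this hunk: `set_cos_and_sin` precomputes the rotary tables once, and each step only gathers rows by position id. A minimal sketch of that cache-then-gather pattern, assuming a conventional RoPE layout with base 10000 (the shapes and helper name are illustrative; the actual vllm-ascend buffers may differ):

```python
import torch


def build_cos_sin_cache(max_positions: int, head_dim: int, base: float = 10000.0):
    # Standard RoPE frequencies: one per pair of head dimensions.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    t = torch.arange(max_positions, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # (max_positions, head_dim // 2)
    return freqs.cos(), freqs.sin()


# Build once at init time (the expensive part)...
cos_cache, sin_cache = build_cos_sin_cache(max_positions=4096, head_dim=128)

# ...then each step is just a row gather over flattened token positions.
positions = torch.tensor([0, 1, 2, 0, 1])  # e.g. two requests, 3 + 2 tokens
cos, sin = cos_cache[positions], sin_cache[positions]
print(cos.shape, sin.shape)  # torch.Size([5, 64]) torch.Size([5, 64])
```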
@@ -1248,9 +1171,7 @@ class NPUModelRunner(GPUModelRunner): # prevent debugger is None if self.debugger is not None: dbg_cfg = getattr(self.debugger, "config", None) - dump_level = str( - getattr(dbg_cfg, "level", - "L1")).upper() if dbg_cfg is not None else "L1" + dump_level = str(getattr(dbg_cfg, "level", "L1")).upper() if dbg_cfg is not None else "L1" if dump_level in ("L0", "MIX"): self.debugger.start(model=self.model) else: @@ -1259,41 +1180,38 @@ class NPUModelRunner(GPUModelRunner): self.sampler.do_async_exponential( b_s=logits_indices.shape[0], head_dim=self.model_config.get_vocab_size(), - generators=self.input_batch.sampling_metadata.generators) + generators=self.input_batch.sampling_metadata.generators, + ) # Encoder-decoder models can only compile the pure decode steps where no # encoder inputs are present. Use eager for the first pass. num_encoder_reqs = len(scheduler_output.scheduled_encoder_inputs) - has_encoder_input = ( - self.model_config.is_encoder_decoder and num_encoder_reqs > 0 - ) + has_encoder_input = self.model_config.is_encoder_decoder and num_encoder_reqs > 0 # Run forward pass - with record_function_or_nullcontext("forward"): - with ( - set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens_padded, - num_tokens_across_dp=num_tokens_across_dp, - aclgraph_runtime_mode=cudagraph_mode, - batch_descriptor=batch_desc, - num_actual_tokens=scheduler_output. - total_num_scheduled_tokens, - model_instance=self.model, - skip_compiled=has_encoder_input), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, - ): - hidden_states = self._model_forward( - num_tokens_padded, input_ids, positions, - intermediate_tensors, inputs_embeds, **model_kwargs) + with ( + record_function_or_nullcontext("forward"), + set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=cudagraph_mode, + batch_descriptor=batch_desc, + num_actual_tokens=scheduler_output.total_num_scheduled_tokens, + model_instance=self.model, + skip_compiled=has_encoder_input, + ), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): + hidden_states = self._model_forward( + num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs + ) with record_function_or_nullcontext("post process"): if self.pcp_size > 1: # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx # ignores the padding from CUDA Graph. 
-                hidden_states = self.pcp_manager.get_restore_hidden_states(
-                    hidden_states
-                )
+                hidden_states = self.pcp_manager.get_restore_hidden_states(hidden_states)
            aux_hidden_states = None
            if self.use_aux_hidden_state_outputs:
                hidden_states, aux_hidden_states = hidden_states
@@ -1328,8 +1246,7 @@
            if not get_pp_group().is_last_rank:
                sample_hidden_states = hidden_states[logits_indices]
-                get_pp_group().send_tensor_dict(
-                    hidden_states.tensors, all_gather_group=get_tp_group())
+                get_pp_group().send_tensor_dict(hidden_states.tensors, all_gather_group=get_tp_group())
                logits = None
            else:
                sample_hidden_states = hidden_states[logits_indices]
@@ -1344,7 +1261,6 @@
                assert broadcasted is not None
                logits = broadcasted["logits"]

-        # Apply structured output bitmasks if present
        self.execute_model_state = ExecuteModelState(
            scheduler_output,
@@ -1405,8 +1321,7 @@
            # the apply_grammar_bitmask uses torch.compile to optimize this; Ascend does not support it now
            logits_dtype = logits.dtype
            logits = logits.to("cpu").float()
-            apply_grammar_bitmask(scheduler_output, grammar_output,
-                                  self.input_batch, logits)
+            apply_grammar_bitmask(scheduler_output, grammar_output, self.input_batch, logits)
            logits = logits.to(self.device).to(logits_dtype)

        with record_function_or_nullcontext("sample_token"):
@@ -1425,7 +1340,7 @@
                hidden_states,
                attn_metadata,
                aux_hidden_states,
-                sample_hidden_states
+                sample_hidden_states,
            )
            self._copy_draft_token_ids_to_cpu(scheduler_output)
@@ -1447,9 +1362,11 @@
        with record_function_or_nullcontext("draft_token"):
            if self.speculative_config:
-                use_padded_batch_for_eagle = self.speculative_config and \
-                    self.speculative_config.use_eagle() and \
-                    not self.speculative_config.disable_padded_drafter_batch
+                use_padded_batch_for_eagle = (
+                    self.speculative_config
+                    and self.speculative_config.use_eagle()
+                    and not self.speculative_config.disable_padded_drafter_batch
+                )
                if use_padded_batch_for_eagle:
                    # EAGLE speculative decoding can use the GPU sampled tokens
                    # as inputs, and does not need to wait for bookkeeping to finish.
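To make the padded-batch convention referenced above concrete: the rejection sampler hands EAGLE a GPU tensor of shape (num_reqs, num_spec_tokens + 1) in which rejected slots hold -1 and accepted tokens form a prefix of each row. A small standalone illustration of how the next input token per request falls out of that layout (the drafter's real `prepare_next_token_ids_padded` does more bookkeeping than this):

```python
import torch

sampled_token_ids = torch.tensor([
    [11, 12, -1, -1],  # one draft token accepted plus its successor
    [21, -1, -1, -1],  # all draft tokens rejected, bonus token only
    [31, 32, 33, 34],  # every draft token accepted
])
valid_counts = (sampled_token_ids != -1).sum(dim=1)  # accepted tokens per request
last_valid = (valid_counts - 1).unsqueeze(1)
next_token_ids = sampled_token_ids.gather(1, last_valid).squeeze(1)
print(valid_counts.tolist(), next_token_ids.tolist())  # [2, 1, 4] [12, 21, 34]
```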
@@ -1469,9 +1386,7 @@ class NPUModelRunner(GPUModelRunner): prompt_logprobs_dict=prompt_logprobs_dict, kv_connector_output=kv_connector_output, pooler_output=[], - ec_connector_output=ec_connector_output - if self.supports_mm_inputs - else None, + ec_connector_output=ec_connector_output if self.supports_mm_inputs else None, cudagraph_stats=cudagraph_stats, ) @@ -1498,14 +1413,14 @@ class NPUModelRunner(GPUModelRunner): sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: if lmhead_tp_enable() and logits is not None: - logits = logits[:self.input_batch.num_reqs] + logits = logits[: self.input_batch.num_reqs] return self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) if lmhead_tp_enable() and logits is not None: - logits = logits[:len(spec_decode_metadata.logits_indices)] + logits = logits[: len(spec_decode_metadata.logits_indices)] sampler_output = self.rejection_sampler( spec_decode_metadata, None, # draft_probs @@ -1513,8 +1428,7 @@ class NPUModelRunner(GPUModelRunner): sampling_metadata, ) if self.need_accepted_tokens: # TODO remove this if - self._update_states_after_model_execute( - sampler_output.sampled_token_ids) + self._update_states_after_model_execute(sampler_output.sampled_token_ids) return sampler_output # TODO: remove this func after eagle_proposer is refactored and @@ -1528,16 +1442,15 @@ class NPUModelRunner(GPUModelRunner): num_scheduled_tokens: int, spec_decode_metadata: SpecDecodeMetadata | None, ) -> tuple[ - LogprobsLists | None, - list[list[int]], - dict[str, LogprobsTensors | None], - list[str], - dict[str, int], - list[int], + LogprobsLists | None, + list[list[int]], + dict[str, LogprobsTensors | None], + list[str], + dict[str, int], + list[int], ]: # TODO: implement PR 28597 from vllm - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] + discard_sampled_tokens_req_indices = self.discard_request_indices.np[: self.num_discarded_requests] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -1583,9 +1496,7 @@ class NPUModelRunner(GPUModelRunner): self.input_batch.prev_sampled_token_ids = sampled_token_ids self.input_batch.prev_req_id_to_index = { - req_id: i - for i, req_id in enumerate(self.input_batch.req_ids) - if i not in invalid_req_indices_set + req_id: i for i, req_id in enumerate(self.input_batch.req_ids) if i not in invalid_req_indices_set } # Cache the sampled tokens in the model runner, so that the scheduler @@ -1596,9 +1507,7 @@ class NPUModelRunner(GPUModelRunner): req_ids = self.input_batch.req_ids for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: - sampled_ids = [ - -1 - ] if req_idx not in invalid_req_indices_set else None + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None else: sampled_ids = valid_sampled_token_ids[req_idx] @@ -1612,10 +1521,10 @@ class NPUModelRunner(GPUModelRunner): assert end_idx <= self.max_model_len, ( "Sampled token IDs exceed the max model length. 
" f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.max_model_len}") + f"{self.max_model_len}" + ) - self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx @@ -1624,9 +1533,11 @@ class NPUModelRunner(GPUModelRunner): req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - logprobs_lists = (logprobs_tensors.tolists(cu_num_tokens) - if not self.use_async_scheduling - and logprobs_tensors is not None else None) + logprobs_lists = ( + logprobs_tensors.tolists(cu_num_tokens) + if not self.use_async_scheduling and logprobs_tensors is not None + else None + ) # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( @@ -1656,18 +1567,16 @@ class NPUModelRunner(GPUModelRunner): # all-gather a list of hidden-states in sp scene @staticmethod def _all_gather_hidden_states_list(hidden_states_list): - return [ - NPUModelRunner._all_gather_hidden_states(hidden_states) - for hidden_states in hidden_states_list - ] + return [NPUModelRunner._all_gather_hidden_states(hidden_states) for hidden_states in hidden_states_list] # all-gather hidden-states in last layer with aux-hidden-states in sp scene @staticmethod def _all_gather_hidden_states_and_aux(hidden_states): if isinstance(hidden_states, tuple): - return (NPUModelRunner._all_gather_hidden_states(hidden_states[0]), - NPUModelRunner._all_gather_hidden_states_list( - hidden_states[1])) + return ( + NPUModelRunner._all_gather_hidden_states(hidden_states[0]), + NPUModelRunner._all_gather_hidden_states_list(hidden_states[1]), + ) return NPUModelRunner._all_gather_hidden_states(hidden_states) def _model_forward( @@ -1677,26 +1586,35 @@ class NPUModelRunner(GPUModelRunner): positions: torch.Tensor | None = None, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, - **model_kwargs: dict[str, Any],): + **model_kwargs: dict[str, Any], + ): assert self.model is not None hidden_states = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **model_kwargs) + **model_kwargs, + ) forward_context = get_forward_context() assert forward_context is not None - if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ - not forward_context.capturing and not self.use_sparse: + if ( + forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL + and not forward_context.capturing + and not self.use_sparse + ): assert positions is not None - update_full_graph_params(self.attn_backend, self.update_stream, forward_context, - num_tokens_padded, self.vllm_config, - self.speculative_config, positions.shape[0]) - if get_forward_context().sp_enabled and not isinstance( - hidden_states, IntermediateTensors): - hidden_states = self._all_gather_hidden_states_and_aux( - hidden_states) + update_full_graph_params( + self.attn_backend, + self.update_stream, + forward_context, + num_tokens_padded, + self.vllm_config, + self.speculative_config, + positions.shape[0], + ) + if get_forward_context().sp_enabled and not isinstance(hidden_states, IntermediateTensors): + hidden_states = self._all_gather_hidden_states_and_aux(hidden_states) return hidden_states def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: @@ 
-1744,10 +1662,7 @@ class NPUModelRunner(GPUModelRunner): return False, None, cudagraph_mode if self._skip_all_reduce_across_dp_group(): - num_tokens_after_padding = torch.tensor([num_tokens_padded] * - self.dp_size, - device="cpu", - dtype=torch.int32) + num_tokens_after_padding = torch.tensor([num_tokens_padded] * self.dp_size, device="cpu", dtype=torch.int32) return False, num_tokens_after_padding, cudagraph_mode tensor = torch.zeros(2, self.dp_size, device="cpu", dtype=torch.int32) @@ -1780,46 +1695,46 @@ class NPUModelRunner(GPUModelRunner): force_uniform_decode: bool | None = None, force_has_lora: bool | None = None, num_encoder_reqs: int = 0, - ) -> tuple[CUDAGraphMode, BatchDescriptor, bool, - torch.Tensor | None, CUDAGraphStat | None]: - + ) -> tuple[CUDAGraphMode, BatchDescriptor, bool, torch.Tensor | None, CUDAGraphStat | None]: num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) uniform_decode = ( - ((max_num_scheduled_tokens == self.uniform_decode_query_len) and - (num_tokens == max_num_scheduled_tokens * num_reqs)) - if force_uniform_decode is None else force_uniform_decode) + ( + (max_num_scheduled_tokens == self.uniform_decode_query_len) + and (num_tokens == max_num_scheduled_tokens * num_reqs) + ) + if force_uniform_decode is None + else force_uniform_decode + ) # Encoder-decoder models only support CG for decoder_step > 0 (no enc_output # is present). Also, chunked-prefill is disabled, so batches are uniform. - has_encoder_output = (self.model_config.is_encoder_decoder - and num_encoder_reqs > 0) - has_lora = (len(self.input_batch.lora_id_to_lora_request) > 0 - if force_has_lora is None else force_has_lora) + has_encoder_output = self.model_config.is_encoder_decoder and num_encoder_reqs > 0 + has_lora = len(self.input_batch.lora_id_to_lora_request) > 0 if force_has_lora is None else force_has_lora # ruff: noqa: E731 dispatch_cudagraph = ( - lambda num_tokens, disable_full: self.cudagraph_dispatcher.
- dispatch( + lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch( num_tokens=num_tokens, has_lora=has_lora, uniform_decode=uniform_decode, disable_full=disable_full, - ) if not force_eager else - (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))) + ) + if not force_eager + else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) + ) - cudagraph_mode, batch_descriptor = dispatch_cudagraph( - num_tokens_padded, use_cascade_attn or has_encoder_output) + cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded, use_cascade_attn or has_encoder_output) num_tokens_padded = batch_descriptor.num_tokens if enable_sp(self.vllm_config): - assert (batch_descriptor.num_tokens % - self.vllm_config.parallel_config.tensor_parallel_size == 0 - ), ("Sequence parallelism requires num_tokens to be " - "a multiple of tensor parallel size") + assert batch_descriptor.num_tokens % self.vllm_config.parallel_config.tensor_parallel_size == 0, ( + "Sequence parallelism requires num_tokens to be a multiple of tensor parallel size" + ) # Extra coordination is needed when running data-parallel, since we must coordinate # across ranks should_ubatch, num_tokens_across_dp = False, None if self.vllm_config.parallel_config.data_parallel_size > 1: - _, num_tokens_across_dp, synced_cudagraph_mode = self._sync_batch_across_dp(num_tokens_padded=num_tokens_padded, - cudagraph_mode=cudagraph_mode.value, - ) + _, num_tokens_across_dp, synced_cudagraph_mode = self._sync_batch_across_dp( + num_tokens_padded=num_tokens_padded, + cudagraph_mode=cudagraph_mode.value, + ) # Extract DP padding if there is any if num_tokens_across_dp is not None: @@ -1828,7 +1743,8 @@ class NPUModelRunner(GPUModelRunner): # Re-dispatch with DP padding cudagraph_mode, batch_descriptor = dispatch_cudagraph( num_tokens_padded, - disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value,) + disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value, + ) # Assert to make sure the agreed-upon token count is correct; otherwise # num_tokens_across_dp will no longer be valid assert batch_descriptor.num_tokens == num_tokens_padded @@ -1883,8 +1799,7 @@ class NPUModelRunner(GPUModelRunner): else: max_seq_len = self.seq_lens.np[:num_reqs].max().item() if use_spec_decode and self.need_accepted_tokens: - self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.num_accepted_tokens.np[:num_reqs] = self.input_batch.num_accepted_tokens_cpu[:num_reqs] self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() @@ -1893,15 +1808,18 @@ class NPUModelRunner(GPUModelRunner): def _get_pcp_metadata(num_tokens): if not self.use_cp: return None - return self.pcp_manager.generate_pcp_metadata(num_tokens, self.query_lens, self.input_batch, num_scheduled_tokens_np) + return self.pcp_manager.generate_pcp_metadata( + num_tokens, self.query_lens, self.input_batch, num_scheduled_tokens_np + ) def _get_block_table_and_slot_mapping(kv_cache_gid: int): assert num_reqs_padded is not None and num_tokens_padded is not None kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec maybe_pcp_full_tokens = ( - num_tokens_padded if self.pcp_size == 1 else - num_tokens * self.pcp_size - - sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs])) + num_tokens_padded + if self.pcp_size == 1 + else num_tokens * self.pcp_size - sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs]) + ) if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): blk_table_tensor = torch.zeros( (num_reqs_padded,
1), @@ -1942,9 +1860,7 @@ class NPUModelRunner(GPUModelRunner): # TODO seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], # TODO - num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ - :num_reqs_padded - ], + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs_padded], num_reqs=num_reqs_padded, num_actual_tokens=num_tokens, max_query_len=max_query_len, @@ -1962,9 +1878,7 @@ class NPUModelRunner(GPUModelRunner): if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill: cm_base.num_logits_indices = logits_indices.size(0) - cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill( - logits_indices - ) + cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill(logits_indices) def _build_attn_group_metadata( kv_cache_gid: int, @@ -1975,9 +1889,7 @@ class NPUModelRunner(GPUModelRunner): attn_group = self.attn_groups[kv_cache_gid][attn_gid] builder = attn_group.get_metadata_builder(ubid or 0) cascade_attn_prefix_len = ( - cascade_attn_prefix_lens[kv_cache_gid][attn_gid] - if cascade_attn_prefix_lens - else 0 + cascade_attn_prefix_lens[kv_cache_gid][attn_gid] if cascade_attn_prefix_lens else 0 ) extra_attn_metadata_args = {} @@ -1986,15 +1898,11 @@ class NPUModelRunner(GPUModelRunner): patch_torch_npu_argsort() extra_attn_metadata_args = dict( num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded], - num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ - :num_reqs_padded - ], + num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[:num_reqs_padded], ) if for_cudagraph_capture: - attn_metadata_i = builder.build_for_cudagraph_capture( - common_attn_metadata - ) + attn_metadata_i = builder.build_for_cudagraph_capture(common_attn_metadata) else: attn_metadata_i = builder.build( common_prefix_len=cascade_attn_prefix_len, @@ -2015,8 +1923,7 @@ class NPUModelRunner(GPUModelRunner): # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. spec_decode_common_attn_metadata = None - for kv_cache_gid, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups): + for kv_cache_gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups): cm = copy(cm_base) # shallow copy # Basically only the encoder seq_lens, block_table and slot_mapping change # for each kv_cache_group. @@ -2026,9 +1933,7 @@ num_reqs_padded, ) if kv_cache_gid > 0: - cm.block_table_tensor, cm.slot_mapping = ( - _get_block_table_and_slot_mapping(kv_cache_gid) - ) + cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid) if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: @@ -2064,9 +1969,7 @@ class NPUModelRunner(GPUModelRunner): # Currently the drafter still only uses piecewise cudagraphs (and modifies # the attention metadata in place), and therefore does not want to use # padded attention metadata (see the toy sketch below).
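A toy sketch of the unpadding step that follows (hedged: `ToyAttnMetadata` and its fields are illustrative stand-ins, not the real `CommonAttentionMetadata` API):

# Minimal illustration only; field names are assumptions for the sketch.
from dataclasses import dataclass

import torch


@dataclass
class ToyAttnMetadata:
    query_start_loc: torch.Tensor  # [num_reqs_padded + 1]
    seq_lens: torch.Tensor  # [num_reqs_padded]
    num_actual_tokens: int

    def unpadded(self, num_tokens: int, num_reqs: int) -> "ToyAttnMetadata":
        # Drop the rows that exist only for cudagraph padding, so the drafter
        # can mutate the metadata in place without touching padded slots.
        return ToyAttnMetadata(
            query_start_loc=self.query_start_loc[: num_reqs + 1],
            seq_lens=self.seq_lens[:num_reqs],
            num_actual_tokens=num_tokens,
        )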
- spec_decode_common_attn_metadata = ( - spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) - ) + spec_decode_common_attn_metadata = spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) return attn_metadata, spec_decode_common_attn_metadata @torch.inference_mode() @@ -2074,7 +1977,7 @@ self, num_tokens: int, with_prefill: bool = False, - cudagraph_runtime_mode: Optional[CUDAGraphMode] = None, + cudagraph_runtime_mode: CUDAGraphMode | None = None, force_attention: bool = False, uniform_decode: bool = False, is_profile: bool = False, @@ -2126,26 +2029,23 @@ class NPUModelRunner(GPUModelRunner): self.query_lens = torch.from_numpy(num_scheduled_tokens) num_tokens_unpadded = int(num_scheduled_tokens.sum()) num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) - _cudagraph_mode, batch_desc, _, num_tokens_across_dp, _ = ( - self._determine_batch_execution_and_padding( - num_tokens=num_tokens_unpadded, - num_reqs=num_reqs, - num_scheduled_tokens_np=num_scheduled_tokens, - max_num_scheduled_tokens=max_query_len, - use_cascade_attn=False, - allow_microbatching=allow_microbatching, - force_eager=is_profile - or (cudagraph_runtime_mode == CUDAGraphMode.NONE), - # `force_uniform_decode` is used for cudagraph capture; because for - # capturing mixed prefill-decode batches, we sometimes use - # num_tokens == num_reqs which looks like a uniform decode batch to the - # dispatcher; but we actually want to capture a piecewise cudagraph - force_uniform_decode=uniform_decode, - # `force_has_lora` is used for cudagraph capture; because LoRA is - # activated later in the context manager, but we need to know the - # LoRA state when determining the batch descriptor for capture - force_has_lora=activate_lora, - ) + _cudagraph_mode, batch_desc, _, num_tokens_across_dp, _ = self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens, + max_num_scheduled_tokens=max_query_len, + use_cascade_attn=False, + allow_microbatching=allow_microbatching, + force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + # `force_uniform_decode` is used for cudagraph capture; because for + # capturing mixed prefill-decode batches, we sometimes use + # num_tokens == num_reqs which looks like a uniform decode batch to the + # dispatcher; but we actually want to capture a piecewise cudagraph + force_uniform_decode=uniform_decode, + # `force_has_lora` is used for cudagraph capture; because LoRA is + # activated later in the context manager, but we need to know the + # LoRA state when determining the batch descriptor for capture + force_has_lora=activate_lora, ) if self.pcp_size * self.dcp_size > 1: self.pcp_manager.init_batch_info( @@ -2160,9 +2060,7 @@ class NPUModelRunner(GPUModelRunner): f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}." ) num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = ( - batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs - ) + num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs if num_tokens_across_dp is not None and num_tokens_padded != num_tokens: # padding is needed if the padding of `num_tokens` was triggered inside CudagraphDispatcher num_tokens_across_dp[:] = num_tokens_padded @@ -2174,10 +2072,11 @@ class NPUModelRunner(GPUModelRunner): # it only happens for cudagraph_runtime_mode=FULL.
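A standalone illustration of the DP re-padding above, where every rank must adopt the same padded token count so collective ops agree on shapes (hedged: toy values, not the real _sync_batch_across_dp protocol):

import torch

# Per-rank scheduled token counts before padding (invented numbers).
num_tokens_across_dp = torch.tensor([96, 112, 80, 128], dtype=torch.int32)
# Suppose the dispatcher padded this rank's batch up to a capture size:
num_tokens_padded = 128
# All entries are overwritten with the padded count; otherwise collectives
# inside the model would disagree on tensor shapes across DP ranks.
num_tokens_across_dp[:] = num_tokens_padded
assert int(num_tokens_across_dp.min()) == num_tokens_padded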
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: if create_mixed_batch: - raise NotImplementedError("create_mixed_batch is used for warmup deepgemm, vllm-ascend does not need it") + raise NotImplementedError( + "create_mixed_batch is only used to warm up deepgemm; vllm-ascend does not need it" + ) self.attn_state = AscendAttentionState.DecodeOnly - if self.speculative_config and \ - self.speculative_config.method == "mtp": + if self.speculative_config and self.speculative_config.method == "mtp": # `AscendAttentionState.SpecDecoding` is only designed for mla if self.vllm_config.model_config.use_mla: self.attn_state = AscendAttentionState.SpecDecoding @@ -2188,7 +2087,11 @@ class NPUModelRunner(GPUModelRunner): # seq_lens. We use this seq_len only when capturing graph, and still use max_query_len # in inference. This will be removed once npu_fused_infer_attention_score # outperforms _npu_paged_attention in all cases. - seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) else max_query_len # type: ignore[assignment] + seq_lens = ( + SEQ_LEN_WITH_MAX_PA_WORKSPACE + if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) + else max_query_len + ) # type: ignore[assignment] self.seq_lens.np[:num_reqs_padded] = seq_lens self.seq_lens.np[num_reqs_padded:] = 0 self.seq_lens.copy_to_gpu() @@ -2207,16 +2110,13 @@ class NPUModelRunner(GPUModelRunner): ) with self.maybe_dummy_run_with_lora( - self.lora_config, - num_scheduled_tokens, - num_sampled_tokens, + self.lora_config, + num_scheduled_tokens, + num_sampled_tokens, ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens - if self.is_multimodal_model and not self.model_config.is_encoder_decoder: - input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] - elif self.enable_prompt_embeds: + if (self.is_multimodal_model and not self.model_config.is_encoder_decoder) or self.enable_prompt_embeds: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] else: @@ -2236,33 +2136,27 @@ class NPUModelRunner(GPUModelRunner): if get_pp_group().is_first_rank: intermediate_tensors = None else: - # When PP and flashcomm1 are enabled, during dummy_run the estimated space should divide num_tokens by tp_size; - # otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading to incorrect memory estimation and potentially causing OOM. + # When PP and flashcomm1 are enabled, during dummy_run the estimated space should divide num_tokens by + # tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading + # to incorrect memory estimation and potentially causing OOM (ceil-division sketched below).
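A quick numeric check of the ceil-division this comment describes (hedged: toy values only):

# (n + tp - 1) // tp rounds up instead of down, so no tokens are lost
# when num_tokens is not an exact multiple of the TP size.
num_tokens_padded, tp_size = 1000, 8
assert (num_tokens_padded + tp_size - 1) // tp_size == 125  # exact split
num_tokens_padded = 1001
assert (num_tokens_padded + tp_size - 1) // tp_size == 126  # rounds up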
intermediate_tokens = num_tokens_padded if enable_sp(): tp_size = get_tensor_model_parallel_world_size() - intermediate_tokens = (num_tokens_padded + tp_size - - 1) // tp_size + intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size if self.intermediate_tensors is None: max_actual_tokens = self.max_num_tokens if enable_sp(): - max_actual_tokens = (self.max_num_tokens + tp_size - - 1) // tp_size - self.intermediate_tensors = ( - self.model.make_empty_intermediate_tensors( - batch_size=max_actual_tokens, - dtype=self.dtype, - device=self.device)) - intermediate_tensors = IntermediateTensors({ - k: - v[:intermediate_tokens] - for k, v in self.intermediate_tensors.items() - }) + max_actual_tokens = (self.max_num_tokens + tp_size - 1) // tp_size + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=max_actual_tokens, dtype=self.dtype, device=self.device + ) + intermediate_tensors = IntermediateTensors( + {k: v[:intermediate_tokens] for k, v in self.intermediate_tensors.items()} + ) - need_dummy_logits = (not is_profile and lmhead_tp_enable()) + need_dummy_logits = not is_profile and lmhead_tp_enable() max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len - dummy_indices = torch.zeros(max_num_reqs_across_dp, - dtype=torch.int32) + dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32) def dummy_compute_logits(hidden_states): if not need_dummy_logits: @@ -2272,24 +2166,23 @@ class NPUModelRunner(GPUModelRunner): def dummy_drafter_compute_logits(hidden_states): if not need_dummy_logits or self.drafter is None: return - if hasattr(self.drafter, "model") and hasattr( - self.drafter.model, "compute_logits"): - return self.drafter.model.compute_logits( - hidden_states[dummy_indices]) + if hasattr(self.drafter, "model") and hasattr(self.drafter.model, "compute_logits"): + return self.drafter.model.compute_logits(hidden_states[dummy_indices]) with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens_padded, - num_tokens_across_dp=num_tokens_across_dp, - in_profile_run=is_profile, - num_actual_tokens=num_tokens_padded, - aclgraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_desc, - model_instance=self.model): + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + in_profile_run=is_profile, + num_actual_tokens=num_tokens_padded, + aclgraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_desc, + model_instance=self.model, + ): outputs = self._model_forward( - num_tokens_padded, input_ids, positions, - intermediate_tensors, inputs_embeds) + num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds + ) if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: @@ -2306,7 +2199,8 @@ class NPUModelRunner(GPUModelRunner): batch_descriptor=batch_desc, dummy_compute_logits=dummy_drafter_compute_logits, in_graph_capturing=not force_attention, - is_profile=is_profile) + is_profile=is_profile, + ) if is_profile and self.dynamic_eplb: self.model.clear_all_moe_loads() if self.dynamic_eplb: @@ -2325,10 +2219,8 @@ class NPUModelRunner(GPUModelRunner): # maximum num_tokens. 
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs - num_scheduled_tokens_list[ - -1] += self.max_num_tokens % self.max_num_reqs - num_scheduled_tokens = np.array(num_scheduled_tokens_list, - dtype=np.int32) + num_scheduled_tokens_list[-1] += self.max_num_tokens % self.max_num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) logit_indices = np.cumsum(num_scheduled_tokens) - 1 # TODO: need to run a dummy sampler for generate task hidden_states = hidden_states[logit_indices] @@ -2338,17 +2230,15 @@ class NPUModelRunner(GPUModelRunner): def profile_run(self) -> None: self.eplb_warmup() mc2_tokens_capacity = get_mc2_tokens_capacity() - if self.max_num_tokens > mc2_tokens_capacity and \ - select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}: - self._dummy_run(mc2_tokens_capacity, - with_prefill=True, - is_profile=True) + if self.max_num_tokens > mc2_tokens_capacity and select_moe_comm_method( + mc2_tokens_capacity, self.vllm_config + ) in {MoECommType.MC2, MoECommType.FUSED_MC2}: + self._dummy_run(mc2_tokens_capacity, with_prefill=True, is_profile=True) origin_max_num_tokens = self.max_num_tokens # in the pcp scenario, the split sequence needs to be used for the profile run # TODO: after the vllm pcp function is launched, this logic needs to be upstreamed to the community if self.pcp_size > 1: - self.max_num_tokens = math.ceil(self.max_num_tokens / - (self.pcp_size * 2)) * 2 + self.max_num_tokens = math.ceil(self.max_num_tokens / (self.pcp_size * 2)) * 2 super().profile_run() self.max_num_tokens = origin_max_num_tokens @@ -2372,21 +2262,16 @@ class NPUModelRunner(GPUModelRunner): with get_tp_context(self.drafter): self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) + self.model.set_aux_hidden_state_layers(self.model.get_eagle3_aux_hidden_state_layers()) if self.lora_config: - self.model = self.load_lora_model(self.model, self.vllm_config, - self.device) - logger.info("Loading model weights took %.4f GB", - m.consumed_memory / float(2**30)) + self.model = self.load_lora_model(self.model, self.vllm_config, self.device) + logger.info("Loading model weights took %.4f GB", m.consumed_memory / float(2**30)) # wrap the model with full graph wrapper if needed. if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.update_stream: torch.npu.Stream = torch.npu.Stream() - self.model = ACLGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -2401,12 +2286,11 @@ class NPUModelRunner(GPUModelRunner): self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) # NOTE(cmq): initialize_attn_backend must be called before using self.attn_groups self.initialize_attn_backend(kv_cache_config) - self.use_hybrid_blocks = (len(self.attn_groups) > 1) + self.use_hybrid_blocks = len(self.attn_groups) > 1 # NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
- self.need_accepted_tokens = any([ - isinstance(attn_group[0].kv_cache_spec, MambaSpec) - for attn_group in self.attn_groups - ]) + self.need_accepted_tokens = any( + [isinstance(attn_group[0].kv_cache_spec, MambaSpec) for attn_group in self.attn_groups] + ) self.may_reinitialize_input_batch(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) @@ -2414,15 +2298,13 @@ class NPUModelRunner(GPUModelRunner): if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) - def _align_memory(self, tensor: torch.Tensor, - alignment: int) -> torch.Tensor: + def _align_memory(self, tensor: torch.Tensor, alignment: int) -> torch.Tensor: data_ptr = tensor.data_ptr() aligned_addr = (data_ptr + alignment - 1) // alignment * alignment offset = (aligned_addr - data_ptr) // tensor.element_size() - return tensor[int(offset):] + return tensor[int(offset) :] - def initialize_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ Initialize the memory buffer for KV cache. @@ -2435,25 +2317,20 @@ class NPUModelRunner(GPUModelRunner): # Initialize the memory buffer for KV cache kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) # Change the memory buffer to the desired shape - kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, - kv_cache_raw_tensors) + kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors) # Set up cross-layer KV cache sharing - for layer_name, target_layer_name in self.shared_kv_cache_layers.items( - ): - logger.debug("%s reuses KV cache of %s", layer_name, - target_layer_name) + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) kv_caches[layer_name] = kv_caches[target_layer_name] from vllm.v1.worker.utils import bind_kv_cache + num_attn_module = 2 if self.model_config.hf_text_config.model_type == "longcat_flash" else 1 - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches, num_attn_module) + bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches, num_attn_module) return kv_caches - def _allocate_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ Initializes the KV cache buffer with the correct size. The buffer needs to be reshaped to the desired shape before being used by the models. @@ -2468,45 +2345,39 @@ class NPUModelRunner(GPUModelRunner): corresponding memory buffer for KV cache. dict[str, tuple(torch.Tensor, torch.Tensor)] A map between layer names to their corresponding memory buffer for K cache and V cache. 
- """ + """ # init kv cache tensors - kv_cache_raw_tensors: dict[str, Union[torch.Tensor, - Optional[torch.Tensor]]] = {} + kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {} # prefill disaggregation need the addr of cache tensor be aligned with 2M alignment = 2 * 1024 * 1024 for kv_cache_tensor in kv_cache_config.kv_cache_tensors: # TODO: REFACTOR ME to sharing hybrid cache for idx in range(len(kv_cache_tensor.shared_by)): layer_name = kv_cache_tensor.shared_by[idx] - if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys( - ): + if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors: # for mamba linear attention if self.vllm_config.kv_transfer_config is None: - tensor = torch.zeros(kv_cache_tensor.size, - dtype=torch.int8, - device=self.device) + tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=self.device) else: cache_size_aligned = kv_cache_tensor.size + alignment - tensor = torch.zeros(cache_size_aligned, - dtype=torch.int8, - device=self.device) - tensor = self._align_memory( - tensor, alignment)[:kv_cache_tensor.size] + tensor = torch.zeros(cache_size_aligned, dtype=torch.int8, device=self.device) + tensor = self._align_memory(tensor, alignment)[: kv_cache_tensor.size] for layer_name_inner in kv_cache_tensor.shared_by: # shared the kvcache between the self_attn specs in the same group if "linear_attn" in layer_name_inner: kv_cache_raw_tensors[layer_name_inner] = tensor - elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys( - ): + elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors: # NOTE: We need to init k cache tensor (nope cache tensor in mla) and # v cache tensor (rope cache tensor in mla) separately to support prefill disaggregation, # as it only support the 0-dim of kv_cache is `num_blocks`. # For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim # and rope head dim. 
if self.model_config.use_mla: - head_size = self.model_config.hf_text_config.qk_rope_head_dim + \ - self.model_config.hf_text_config.kv_lora_rank + head_size = ( + self.model_config.hf_text_config.qk_rope_head_dim + + self.model_config.hf_text_config.kv_lora_rank + ) dsa_k_cache_factor = None dsa_k_cache_size = None @@ -2521,59 +2392,42 @@ class NPUModelRunner(GPUModelRunner): dsa_k_cache_factor = 2 k_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.kv_lora_rank v_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.qk_rope_head_dim - dsa_k_cache_size = int(kv_cache_tensor.size // - dsa_k_cache_factor) + dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor) else: # for other deepseek models, use MLAAttentionSpec k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank v_tensor_split_factor = head_size / self.model_config.hf_text_config.qk_rope_head_dim - k_tensor_size = int(kv_cache_tensor.size // - k_tensor_split_factor) - v_tensor_size = int(kv_cache_tensor.size // - v_tensor_split_factor) + k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor) + v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor) # for other attentions, e.g., self_attn, sliding window attn if self.vllm_config.kv_transfer_config is None: - k_tensor = torch.zeros(k_tensor_size, - dtype=torch.int8, - device=self.device) - v_tensor = torch.zeros(v_tensor_size, - dtype=torch.int8, - device=self.device) + k_tensor = torch.zeros(k_tensor_size, dtype=torch.int8, device=self.device) + v_tensor = torch.zeros(v_tensor_size, dtype=torch.int8, device=self.device) #### k cache: for deepseek sparse attention if dsa_k_cache_factor is not None: - dsa_k_cache_tensor = torch.zeros( - dsa_k_cache_size, - dtype=torch.int8, - device=self.device) + dsa_k_cache_tensor = torch.zeros(dsa_k_cache_size, dtype=torch.int8, device=self.device) else: - k_tensor = torch.zeros(k_tensor_size + alignment, - dtype=torch.int8, - device=self.device) - v_tensor = torch.zeros(v_tensor_size + alignment, - dtype=torch.int8, - device=self.device) - k_tensor = self._align_memory( - k_tensor, alignment)[:k_tensor_size] - v_tensor = self._align_memory( - v_tensor, alignment)[:v_tensor_size] + k_tensor = torch.zeros(k_tensor_size + alignment, dtype=torch.int8, device=self.device) + v_tensor = torch.zeros(v_tensor_size + alignment, dtype=torch.int8, device=self.device) + k_tensor = self._align_memory(k_tensor, alignment)[:k_tensor_size] + v_tensor = self._align_memory(v_tensor, alignment)[:v_tensor_size] #### k cache: for deepseek sparse attention if dsa_k_cache_factor is not None and dsa_k_cache_size is not None: dsa_k_cache_tensor = torch.zeros( - dsa_k_cache_size + alignment, - dtype=torch.int8, - device=self.device) - dsa_k_cache_tensor = self._align_memory( - dsa_k_cache_tensor, - alignment)[:dsa_k_cache_size] + dsa_k_cache_size + alignment, dtype=torch.int8, device=self.device + ) + dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size] for layer_name_inner in kv_cache_tensor.shared_by: # share the KV cache between the self_attn specs in the same group - if ("attn" in layer_name_inner - and "linear_attn" not in layer_name_inner): - kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor) if \ - not self.use_sparse else (k_tensor, v_tensor, dsa_k_cache_tensor) + if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner: + kv_cache_raw_tensors[layer_name_inner] = ( + (k_tensor, v_tensor) + if not
self.use_sparse + else (k_tensor, v_tensor, dsa_k_cache_tensor) + ) layer_names = set() for group in kv_cache_config.kv_cache_groups: @@ -2581,8 +2435,7 @@ class NPUModelRunner(GPUModelRunner): if layer_name in self.runner_only_attn_layers: continue layer_names.add(layer_name) - assert layer_names == set(kv_cache_raw_tensors.keys( - )), "Some layers are not correctly initialized" + assert layer_names == set(kv_cache_raw_tensors.keys()), "Some layers are not correctly initialized" return kv_cache_raw_tensors @@ -2602,7 +2455,7 @@ class NPUModelRunner(GPUModelRunner): Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. """ - kv_caches: Dict[str, torch.Tensor] = {} + kv_caches: dict[str, torch.Tensor] = {} for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec attn_backend = group.backend @@ -2616,15 +2469,15 @@ class NPUModelRunner(GPUModelRunner): raw_dsa_k_tensor = None if self.use_sparse: raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore - layer_name] + layer_name + ] assert raw_dsa_k_tensor is not None - sum_page_size_bytes = raw_k_tensor.numel( - ) + raw_v_tensor.numel() + raw_dsa_k_tensor.numel() + sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel() else: raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[ # type: ignore - layer_name] - sum_page_size_bytes = raw_k_tensor.numel( - ) + raw_v_tensor.numel() + layer_name + ] + sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() assert raw_k_tensor is not None assert raw_v_tensor is not None assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0 @@ -2639,20 +2492,20 @@ class NPUModelRunner(GPUModelRunner): # the min of all `num_blocks`. Verify it here. 
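A toy version of the bytes-to-blocks bookkeeping above (hedged: invented sizes):

# page_size_bytes is the per-block footprint; the raw K + V byte counts must
# tile into whole pages, and the quotient is the usable block count.
page_size_bytes = 2 * 16 * 8 * 128  # 2 (K+V) * block_size * kv_heads * head_size, int8
sum_page_size_bytes = 1024 * page_size_bytes  # raw K bytes + raw V bytes
assert sum_page_size_bytes % page_size_bytes == 0
num_blocks = sum_page_size_bytes // page_size_bytes  # 1024 usable blocks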
assert num_blocks >= kv_cache_config.num_blocks - if hasattr(attn_backend, "get_supported_block_size" - ) and self.use_hybrid_blocks: + if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks: block_size = attn_backend.get_supported_block_size()[0] block_size_chunk = kv_cache_spec.block_size // block_size kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks * block_size_chunk, block_size, + num_blocks * block_size_chunk, + block_size, kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size) + kv_cache_spec.head_size, + ) else: kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size) + num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size + ) dtype = kv_cache_spec.dtype if not self.model_config.use_mla: k_shape = kv_cache_shape[1:] @@ -2661,12 +2514,16 @@ class NPUModelRunner(GPUModelRunner): # k_cache: nope_cache v_cache: rope_cache mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape k_shape = [ - mla_num_blocks, mla_block_size, num_kv_heads, - self.model_config.hf_text_config.kv_lora_rank + mla_num_blocks, + mla_block_size, + num_kv_heads, + self.model_config.hf_text_config.kv_lora_rank, ] v_shape = [ - mla_num_blocks, mla_block_size, num_kv_heads, - self.model_config.hf_text_config.qk_rope_head_dim + mla_num_blocks, + mla_block_size, + num_kv_heads, + self.model_config.hf_text_config.qk_rope_head_dim, ] k_cache = raw_k_tensor.view(dtype).view(k_shape) v_cache = raw_v_tensor.view(dtype).view(v_shape) @@ -2674,23 +2531,17 @@ class NPUModelRunner(GPUModelRunner): k_cache = maybe_trans_nz(k_cache) v_cache = maybe_trans_nz(v_cache) if self.use_sparse and raw_dsa_k_tensor is not None: - dsa_k_cache_shape = (num_blocks, - kv_cache_spec.block_size, 1, 128) - dsa_k_cache_size = ( - num_blocks - ) * kv_cache_spec.block_size * 128 * dtype.itemsize - dsa_k_cache = raw_dsa_k_tensor[:dsa_k_cache_size].view( - dtype).view(dsa_k_cache_shape) + dsa_k_cache_shape = (num_blocks, kv_cache_spec.block_size, 1, 128) + dsa_k_cache_size = (num_blocks) * kv_cache_spec.block_size * 128 * dtype.itemsize + dsa_k_cache = raw_dsa_k_tensor[:dsa_k_cache_size].view(dtype).view(dsa_k_cache_shape) kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache) else: kv_caches[layer_name] = (k_cache, v_cache) elif isinstance(kv_cache_spec, MambaSpec): raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor is not None - assert raw_tensor.numel( - ) % kv_cache_spec.page_size_bytes == 0 - num_blocks = raw_tensor.numel( - ) // kv_cache_spec.page_size_bytes + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes assert num_blocks >= kv_cache_config.num_blocks # `num_blocks` is the number of blocks the model runner can use. @@ -2704,16 +2555,13 @@ class NPUModelRunner(GPUModelRunner): state_tensors = [] target_idx = 0 start_idx = 0 - for shape, dtype in zip(kv_cache_spec.shapes, - kv_cache_spec.dtypes): + for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): # Normally there are both a conv state and an ssm state in this loop; some special # models have only a conv state (see the sketch below).
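A self-contained sketch of carving one flat buffer into per-state tensors, as the loop below does for the Mamba conv/ssm states (hedged: toy shapes and dtypes):

import torch

num_blocks = 4
shapes = [(8, 16), (32,)]  # e.g. a conv state and an ssm state
dtypes = [torch.float16, torch.float16]
numel = sum(num_blocks * int(torch.prod(torch.tensor(s))) for s in shapes)
raw = torch.zeros(numel * 2, dtype=torch.int8)  # 2 bytes per fp16 element

start = 0
states = []
for shape, dtype in zip(shapes, dtypes):
    target_shape = (num_blocks, *shape)
    end = start + int(torch.prod(torch.tensor(target_shape)))
    # Reinterpret the int8 buffer as fp16, slice this state's span, reshape.
    states.append(raw.view(dtype)[start:end].view(target_shape))
    start = end
assert [tuple(t.shape) for t in states] == [(4, 8, 16), (4, 32)]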
target_shape = (num_blocks, *shape) - target_idx += torch.prod( - torch.tensor(target_shape)).item() - tensor = raw_tensor.view( - dtype)[start_idx:target_idx].view(target_shape) + target_idx += torch.prod(torch.tensor(target_shape)).item() + tensor = raw_tensor.view(dtype)[start_idx:target_idx].view(target_shape) start_idx = target_idx state_tensors.append(tensor) kv_caches[layer_name] = state_tensors @@ -2722,8 +2570,7 @@ class NPUModelRunner(GPUModelRunner): return kv_caches - def may_reinitialize_input_batch(self, - kv_cache_config: KVCacheConfig) -> None: + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ Re-initialize the input batch if the block sizes are different from `[self.cache_config.block_size]`. This usually happens when there @@ -2735,8 +2582,7 @@ class NPUModelRunner(GPUModelRunner): block_sizes = [ kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups - if not isinstance(kv_cache_group.kv_cache_spec, - EncoderOnlyAttentionSpec) + if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) ] # Generate kernel_block_sizes that match each block_size @@ -2744,14 +2590,12 @@ # use the supported block sizes from the backend # For other backends (like Mamba), use [0] (no splitting) kernel_block_sizes = [] - for kv_cache_group_id, kv_cache_group in enumerate( - kv_cache_config.kv_cache_groups): + for kv_cache_group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group.kv_cache_spec if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): # All layers in the UniformTypeKVCacheSpecs have the same type, # so pick an arbitrary one to dispatch. - kv_cache_spec = next( - iter(kv_cache_spec.kv_cache_specs.values())) + kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values())) if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): continue elif isinstance(kv_cache_spec, AttentionSpec): @@ -2768,9 +2612,7 @@ class NPUModelRunner(GPUModelRunner): supported_sizes = backend.get_supported_block_size() # If no specific sizes supported, use cache config # block_size - kernel_block_size_list = (supported_sizes - if supported_sizes else - [self.cache_config.block_size]) + kernel_block_size_list = supported_sizes if supported_sizes else [self.cache_config.block_size] else: # Fallback to cache config block_size if no backend found kernel_block_size_list = [self.cache_config.block_size] @@ -2782,17 +2624,15 @@ class NPUModelRunner(GPUModelRunner): # of mamba block. In this case, BlockTable.block_size will never be equal # to kernel_block_sizes[0] kernel_block_sizes.append([0]) - if block_sizes != [ - self.cache_config.block_size - ] or kernel_block_sizes != [[self.cache_config.block_size]]: + if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [[self.cache_config.block_size]]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 - "for more details.") + "for more details."
+ ) self.input_batch = NPUInputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=max(self.model_config.max_model_len, - self.max_encoder_len), + max_model_len=max(self.model_config.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -2803,7 +2643,9 @@ class NPUModelRunner(GPUModelRunner): is_pooling_model=self.is_pooling_model, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0), + if self.vllm_config.speculative_config + else 0 + ), kernel_block_sizes=kernel_block_sizes, ) @@ -2811,8 +2653,7 @@ class NPUModelRunner(GPUModelRunner): """ Initialize the attention backends and attention metadata builders. """ - assert len(self.attn_groups) == 0, \ - "Attention backends are already initialized" + assert len(self.attn_groups) == 0, "Attention backends are already initialized" class AttentionGroupKey(NamedTuple): attn_backend: type[AttentionBackend] @@ -2820,11 +2661,8 @@ class NPUModelRunner(GPUModelRunner): def get_attn_backends_for_group( kv_cache_group_spec: KVCacheGroupSpec, - ) -> tuple[dict[AttentionGroupKey, list[str]], - set[type[AttentionBackend]]]: - layers = get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase, - kv_cache_group_spec.layer_names) + ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]: + layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names) attn_backends = {} attn_backend_layers = defaultdict(list) # Dedupe based on full class name; this is a bit safer than @@ -2837,37 +2675,32 @@ class NPUModelRunner(GPUModelRunner): full_cls_name = attn_backend.full_cls_name() layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): - layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ - layer_name] + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] key = (full_cls_name, layer_kv_cache_spec) - attn_backends[key] = AttentionGroupKey(attn_backend, - layer_kv_cache_spec) + attn_backends[key] = AttentionGroupKey(attn_backend, layer_kv_cache_spec) attn_backend_layers[key].append(layer_name) return ( - { - attn_backends[k]: v - for k, v in attn_backend_layers.items() - }, - set(group_key.attn_backend - for group_key in attn_backends.values()), + {attn_backends[k]: v for k, v in attn_backend_layers.items()}, + set(group_key.attn_backend for group_key in attn_backends.values()), ) - def create_attn_groups(attn_backends_map: dict[AttentionBackend, - list[str]], - kv_cache_group_id: int) -> list[AttentionGroup]: + def create_attn_groups( + attn_backends_map: dict[AttentionBackend, list[str]], kv_cache_group_id: int + ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] - for (attn_backend, - kv_cache_spec), layer_names in attn_backends_map.items(): + for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): attn_metadata_builders = [] - attn_metadata_builders.append(attn_backend.get_builder_cls()( - kv_cache_spec, - layer_names, - self.vllm_config, - self.device, - )) - attn_group = AttentionGroup(attn_backend, layer_names, - kv_cache_spec, kv_cache_group_id, - attn_metadata_builders) + attn_metadata_builders.append( + attn_backend.get_builder_cls()( + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, + ) + ) + attn_group = AttentionGroup( + attn_backend, layer_names, kv_cache_spec, kv_cache_group_id, 
attn_metadata_builders ) attn_groups.append(attn_group) return attn_groups @@ -2878,13 +2711,12 @@ class NPUModelRunner(GPUModelRunner): attention_backend_maps.append(attn_backends[0]) attention_backend_list.append(attn_backends[1]) - self._check_and_update_cudagraph_mode(attention_backend_list, - kv_cache_config.kv_cache_groups) + self._check_and_update_cudagraph_mode(attention_backend_list, kv_cache_config.kv_cache_groups) - for i, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups): + for i, kv_cache_group_spec in enumerate(kv_cache_config.kv_cache_groups): attn_backends = get_attn_backends_for_group( # type: ignore - kv_cache_group_spec) + kv_cache_group_spec + ) self.attn_groups.append(create_attn_groups(attn_backends[0], i)) # Calculate reorder batch threshold (if needed) @@ -2897,21 +2729,19 @@ class NPUModelRunner(GPUModelRunner): """ for group in self._attn_group_iterator(): attn_metadata_builder_i = group.get_metadata_builder() - if hasattr(attn_metadata_builder_i, - "reorder_batch_threshold"): # noqa + if hasattr(attn_metadata_builder_i, "reorder_batch_threshold"): # noqa # check that, if any backends reorder batches, the reordering # is compatible (e.g., the decode threshold is the same) - reorder_batch_threshold_i = ( - attn_metadata_builder_i.reorder_batch_threshold) + reorder_batch_threshold_i = attn_metadata_builder_i.reorder_batch_threshold if reorder_batch_threshold_i is not None: # noqa if self.reorder_batch_threshold is not None: - if reorder_batch_threshold_i != \ - self.reorder_batch_threshold: + if reorder_batch_threshold_i != self.reorder_batch_threshold: raise ValueError( f"Attention backend reorders decodes with " f"threshold {reorder_batch_threshold_i} but other " f"backend uses threshold " - f"{self.reorder_batch_threshold}") + f"{self.reorder_batch_threshold}" + ) else: self.reorder_batch_threshold = reorder_batch_threshold_i # noqa @@ -2928,15 +2758,13 @@ class NPUModelRunner(GPUModelRunner): return {} kv_cache_spec: dict[str, KVCacheSpec] = {} - attn_layers = get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase) + attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) # NOTE: Must process Attention/MLAAttention before MambaBase to maintain # ordering expected by graph parameter update logic in attention backends. mamba_layers: dict[str, MambaBase] = {} for layer_name, attn_module in attn_layers.items(): if isinstance(attn_module, Attention): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: + if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of # the target layer.
We skip creating a KVCacheSpec for it, so # that KV cache management logic will act as this layer does @@ -2959,7 +2787,8 @@ class NPUModelRunner(GPUModelRunner): block_size=block_size, num_kv_heads=1, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) + dtype=self.kv_cache_dtype, + ) elif spec := attn_module.get_kv_cache_spec(self.vllm_config): kv_cache_spec[layer_name] = spec @@ -2968,8 +2797,7 @@ class NPUModelRunner(GPUModelRunner): if len(mamba_layers) > 0: if self.vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError( - "Prefix caching is not supported for Mamba yet.") + raise NotImplementedError("Prefix caching is not supported for Mamba yet.") for layer_name, mamba_module in mamba_layers.items(): if spec := mamba_module.get_kv_cache_spec(self.vllm_config): kv_cache_spec[layer_name] = spec @@ -2981,8 +2809,7 @@ class NPUModelRunner(GPUModelRunner): attention_backends: list[set[type[AttentionBackend]]], kv_cache_groups: list[KVCacheGroupSpec], ) -> None: - super()._check_and_update_cudagraph_mode(attention_backends, - kv_cache_groups) + super()._check_and_update_cudagraph_mode(attention_backends, kv_cache_groups) # NOTE: Since aclgraph_batch_sizes cannot be determined until here, # we set the graph params right before initializing the keys. @@ -2992,15 +2819,11 @@ class NPUModelRunner(GPUModelRunner): set_draft_graph_params(self.cudagraph_batch_sizes) def capture_model(self) -> None: - gpu_model_runner_cls = next((cls for cls in self.__class__.__mro__ - if cls.__name__ == "GPUModelRunner"), - None) + gpu_model_runner_cls = next((cls for cls in self.__class__.__mro__ if cls.__name__ == "GPUModelRunner"), None) if gpu_model_runner_cls is None: - raise TypeError("Could not find GPUModelRunner in the MRO. " - "The class hierarchy may have changed.") + raise TypeError("Could not find GPUModelRunner in the MRO. 
The class hierarchy may have changed.") parent_module_name = gpu_model_runner_cls.__module__ - with _torch_cuda_wrapper(), _replace_gpu_model_runner_function_wrapper( - parent_module_name): + with _torch_cuda_wrapper(), _replace_gpu_model_runner_function_wrapper(parent_module_name): GPUModelRunner.capture_model(self) def _prepare_multimodal_fields(self): @@ -3018,16 +2841,14 @@ class NPUModelRunner(GPUModelRunner): if req is None: continue - mm_data = getattr(req, 'multimodal_data', None) + mm_data = getattr(req, "multimodal_data", None) if not mm_data: continue for field in self.multimodal_cpu_fields: if field in mm_data: tensor = mm_data[field] - if isinstance( - tensor, - torch.Tensor) and tensor.device.type != 'cpu': + if isinstance(tensor, torch.Tensor) and tensor.device.type != "cpu": mm_data[field] = tensor.cpu() @@ -3042,15 +2863,12 @@ def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int: @contextmanager def _torch_cuda_wrapper(): - class _EventPlaceholder: - def __init__(self, *args, **kwargs) -> None: self.record = lambda: None self.synchronize = lambda: None class _StreamPlaceholder: - def __init__(self, *args, **kwargs) -> None: pass @@ -3090,7 +2908,7 @@ def _torch_cuda_wrapper(): def _replace_gpu_model_runner_function_wrapper(target_module_name): try: target_module = sys.modules[target_module_name] - setattr(target_module, "graph_capture", graph_capture) + setattr(target_module, "graph_capture", graph_capture) # noqa: B010 yield finally: - setattr(target_module, "graph_capture", graph_capture) + setattr(target_module, "graph_capture", graph_capture) # noqa: B010 diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 30dcbba1..9ff3c9e6 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -17,12 +17,11 @@ # Adapted from vllm-project/vllm/vllm/worker/worker.py # -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING import numpy as np import torch from vllm.config import VllmConfig -from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer if TYPE_CHECKING: @@ -36,6 +35,7 @@ class PCPManager: This manager encapsulates all PCP-related buffers and logic so that the ModelRunner can access them via `self.pcp_manager`. 
""" + num_reqs: int = 0 num_decode_reqs: int = 0 num_prefill_reqs: int = 0 @@ -59,9 +59,7 @@ class PCPManager: self.dcp_world_size = dcp_world_size self.dcp_world_rank = dcp_rank self.speculative_config = vllm_config.speculative_config - self.decode_threshold = 1 + ( - self.speculative_config.num_speculative_tokens - if self.speculative_config else 0) + self.decode_threshold = 1 + (self.speculative_config.num_speculative_tokens if self.speculative_config else 0) self.vllm_config = vllm_config self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs @@ -74,46 +72,42 @@ class PCPManager: pin_memory=pin_memory, ) self.pcp_padded_slot_mapping = torch.full( - (max_buffer_num_tokens, ), + (max_buffer_num_tokens,), fill_value=-1, dtype=torch.int32, device=device, ) self.pcp_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) self.total_num_sampled_tokens_pcp = 0 - self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs, ), - device="cpu", - dtype=torch.int64) + self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs,), device="cpu", dtype=torch.int64) self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() self.pcp_unpad_mask_cpu_tensor = torch.ones( - (max_buffer_num_tokens, ), + (max_buffer_num_tokens,), device="cpu", dtype=torch.bool, ) self.num_actual_tokens_pcp_padded = 0 self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() self.full_indices = list( - range(self.max_num_tokens * self.pcp_world_size * - self.dcp_world_size + self.pcp_world_size * - self.dcp_world_size * self.max_num_reqs)) + range( + self.max_num_tokens * self.pcp_world_size * self.dcp_world_size + + self.pcp_world_size * self.dcp_world_size * self.max_num_reqs + ) + ) if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1: - self.input_ids_pcp_full = CpuGpuBuffer(self.max_num_tokens, - dtype=torch.int32, - device=device, - pin_memory=pin_memory) - self.query_start_loc_pcp_full = CpuGpuBuffer(self.max_num_reqs + 1, - dtype=torch.int32, - device=device, - pin_memory=pin_memory) - self.positions_pcp_full = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=pin_memory) + self.input_ids_pcp_full = CpuGpuBuffer( + self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory + ) + self.query_start_loc_pcp_full = CpuGpuBuffer( + self.max_num_reqs + 1, dtype=torch.int32, device=device, pin_memory=pin_memory + ) + self.positions_pcp_full = torch.zeros( + self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=pin_memory + ) self.positions_pcp_full_np = self.positions_pcp_full.numpy() - self.query_lens_pcp_full = CpuGpuBuffer(self.max_num_reqs, - dtype=torch.int32, - device=device, - pin_memory=pin_memory) + self.query_lens_pcp_full = CpuGpuBuffer( + self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory + ) def _get_cumsum_and_arange( self, @@ -130,8 +124,7 @@ class PCPManager: cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype) total_num_tokens = cu_num_tokens[-1] # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] - cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, - num_scheduled_tokens) + cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens) # Step 3. 
[0, 1, 0, 1, 2, 3, 4, 0, 1, 2] arange = arange_np[:total_num_tokens] - cumsums_offsets @@ -143,15 +136,15 @@ class PCPManager: num_reqs: int, ) -> None: self.num_reqs = num_reqs - is_prefill = (num_scheduled_tokens[:num_reqs] > self.decode_threshold) + is_prefill = num_scheduled_tokens[:num_reqs] > self.decode_threshold if not any(is_prefill): first_prefill = num_reqs else: first_prefill = is_prefill.argmax() self.num_decode_reqs = first_prefill self.num_prefill_reqs = num_reqs - self.num_decode_reqs - self.num_decode_tokens = num_scheduled_tokens[:self.num_decode_reqs].sum() - + self.num_decode_tokens = num_scheduled_tokens[: self.num_decode_reqs].sum() + def update_tokens_for_pcp( self, num_scheduled_tokens: np.ndarray, @@ -208,32 +201,29 @@ class PCPManager: # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size). # We first pad each request's token count up to that multiple. - num_padded_scheduled_tokens = np.ceil( - num_scheduled_tokens / (2 * self.pcp_world_size)).astype( - np.int32) * (2 * self.pcp_world_size) + num_padded_scheduled_tokens = np.ceil(num_scheduled_tokens / (2 * self.pcp_world_size)).astype(np.int32) * ( + 2 * self.pcp_world_size + ) # PCP does not split decode requests. For decode requests, we instead # duplicate the scheduled tokens across the pcp_world_size ranks. - num_padded_scheduled_tokens[:self.num_decode_reqs] = ( - num_scheduled_tokens[:self.num_decode_reqs] * self.pcp_world_size) + num_padded_scheduled_tokens[: self.num_decode_reqs] = ( + num_scheduled_tokens[: self.num_decode_reqs] * self.pcp_world_size + ) # Record how many pads were added per request (padded - original). - self.num_pcp_pads_cpu[:self.num_reqs] = (num_padded_scheduled_tokens - - num_scheduled_tokens) + self.num_pcp_pads_cpu[: self.num_reqs] = num_padded_scheduled_tokens - num_scheduled_tokens # cu_padded_tokens: cumulative sum of padded token counts, # pcp_padded_arange: per-request arange flattened for padded tokens. - cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( - num_padded_scheduled_tokens, arange_np) + cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(num_padded_scheduled_tokens, arange_np) # Build the mask that marks which positions in the padded allgather buffer # correspond to real (unpadded) tokens. - self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = ( - pcp_padded_arange < np.repeat(num_scheduled_tokens, - num_padded_scheduled_tokens)) - unpad_mask_decode = self.pcp_unpad_mask_cpu[:self.num_decode_tokens * - self.pcp_world_size] - unpad_mask_decode = unpad_mask_decode.reshape( - [-1, self.pcp_world_size]) + self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = pcp_padded_arange < np.repeat( + num_scheduled_tokens, num_padded_scheduled_tokens + ) + unpad_mask_decode = self.pcp_unpad_mask_cpu[: self.num_decode_tokens * self.pcp_world_size] + unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_world_size]) unpad_mask_decode[:, 0] = True unpad_mask_decode[:, 1:] = False pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size @@ -242,23 +232,20 @@ class PCPManager: # For prefill requests, we further split the pcp_tokens into two chunks # (head and tail). For decode requests, the chunk equals pcp_tokens. 
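A numeric check of the `_get_cumsum_and_arange` helper used throughout this file (hedged: it simply mirrors the [2, 5, 3] example from the comments above):

import numpy as np

num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
arange_np = np.arange(16, dtype=np.int32)

cu = np.cumsum(num_scheduled_tokens)  # [2, 7, 10]
# Repeat each request's start offset once per token it owns.
offsets = np.repeat(cu - num_scheduled_tokens, num_scheduled_tokens)
arange = arange_np[: cu[-1]] - offsets  # per-request aranges, flattened
assert arange.tolist() == [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]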
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) - pcp_chunk_sizes[:self.num_decode_reqs] = pcp_tokens[:self.num_decode_reqs] + pcp_chunk_sizes[: self.num_decode_reqs] = pcp_tokens[: self.num_decode_reqs] # Build arange-style helpers for pcp tokens and chunk sizes: # - pcp_arange gives indices repeated for each token in pcp_tokens # - pcp_chunk_arange gives indices repeated for each position inside chunks _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np) - _, pcp_chunk_arange = self._get_cumsum_and_arange( - pcp_chunk_sizes, arange_np) + _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes, arange_np) # Mask that marks whether a position belongs to the head chunk (True) # or the tail chunk (False). For decode requests, the tail chunk won't exist # and is handled specially below. - pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, - pcp_tokens) + pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_tokens) - def get_current_rank_positions(positions_start_loc: int | np.ndarray, - rank: int): + def get_current_rank_positions(positions_start_loc: int | np.ndarray, rank: int): """ Compute flattened positions for the given rank with a given start offset for each request (positions_start_loc). @@ -271,59 +258,53 @@ class PCPManager: """ positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) head_start_loc = positions_start_loc + rank * pcp_chunk_sizes - tail_start_loc = ( - positions_start_loc + - (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes) + tail_start_loc = positions_start_loc + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes # Fill head positions using chunk arange offset by head_start_loc. - positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat( - head_start_loc, pcp_chunk_sizes) + positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat(head_start_loc, pcp_chunk_sizes) # Fill tail positions. Note decode requests do not have tail chunks, # so the tail filling is only for prefill positions. positions[~pcp_head_chunk_mask] = ( - pcp_chunk_arange[self.num_decode_tokens:] + - np.repeat(tail_start_loc, pcp_chunk_sizes)[self.num_decode_tokens:]) + pcp_chunk_arange[self.num_decode_tokens :] + + np.repeat(tail_start_loc, pcp_chunk_sizes)[self.num_decode_tokens :] + ) return positions positions = get_current_rank_positions(0, self.pcp_world_rank) # Decode tokens are duplicated only after the all-gather, and their positions are # the same as without prefill context parallel. if self.num_decode_reqs > 0: - positions[:self.num_decode_tokens] = self._get_cumsum_and_arange( - num_scheduled_tokens[:self.num_decode_reqs], arange_np)[1] + positions[: self.num_decode_tokens] = self._get_cumsum_and_arange( + num_scheduled_tokens[: self.num_decode_reqs], arange_np + )[1] # Build the restore index used after allgather (a toy version is sketched below).
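A tiny model of the restore index built just below (hedged: 2 PCP ranks, one 4-token request, no padding):

import numpy as np

pcp_world_size = 2
# With DualChunkSwap, rank r holds chunks r and (2 * pcp_world_size - r - 1)
# of each request, so the global positions each rank contributes are:
positions_rank0 = np.array([0, 3])  # head chunk 0, tail chunk 3
positions_rank1 = np.array([1, 2])  # head chunk 1, tail chunk 2
all_positions = np.concatenate([positions_rank0, positions_rank1])
restore_idx = all_positions.argsort()  # gather order restoring 0, 1, 2, 3
assert all_positions[restore_idx].tolist() == [0, 1, 2, 3]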
         padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
         padded_pos_start_loc[0] = 0
         all_positions_lst = [
-            get_current_rank_positions(padded_pos_start_loc, rank_i)
-            for rank_i in range(self.pcp_world_size)
+            get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size)
         ]
         all_positions = np.concatenate(all_positions_lst)
-        self.pcp_allgather_restore_idx.np[:all_positions.shape[0]] = (
-            all_positions.argsort())
+        self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort()
         self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
-        self.pcp_tokens[:self.num_reqs] = pcp_tokens[:self.num_reqs]
-        self.total_num_sampled_tokens_pcp = pcp_tokens[:self.num_reqs].sum()
+        self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs]
+        self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum()
         return (
-            pcp_tokens[:self.num_reqs],
+            pcp_tokens[: self.num_reqs],
             positions,
         )
 
     def get_logits_indices(self, cu_num_tokens: np.ndarray):
-        return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size -
-                self.num_pcp_pads_cpu_tensor[:self.num_reqs] - 1)
+        return torch.from_numpy(cu_num_tokens) * self.pcp_world_size - self.num_pcp_pads_cpu_tensor[: self.num_reqs] - 1
 
-    def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int,
-                                slot_mapping: torch.Tensor):
+    def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int, slot_mapping: torch.Tensor):
         # After pcp allgather and restore, there are padded tokens in the KV
         # cache, so we need to pad the slot mapping for alignment.
-        pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens_padded * self.pcp_world_size]
+        pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: num_tokens_padded * self.pcp_world_size]
 
-        cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[:num_tokens *
-                                                       self.pcp_world_size]
+        cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[: num_tokens * self.pcp_world_size]
         pcp_padded_slot_mapping.fill_(-1)
-        pcp_padded_slot_mapping[:num_tokens * self.pcp_world_size][cp_unpad_mask] = slot_mapping
+        pcp_padded_slot_mapping[: num_tokens * self.pcp_world_size][cp_unpad_mask] = slot_mapping
         return pcp_padded_slot_mapping
 
     def get_restore_hidden_states(
@@ -333,13 +314,12 @@ class PCPManager:
         # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
         # ignores the padding from CUDA Graph.
         from vllm.distributed.parallel_state import get_pcp_group
+
         hidden_states = get_pcp_group().all_gather(
-            hidden_states[:self.num_actual_tokens_pcp_padded //
-                          self.pcp_world_size],
+            hidden_states[: self.num_actual_tokens_pcp_padded // self.pcp_world_size],
             0,
         )
-        restore_idx = self.pcp_allgather_restore_idx.gpu[:hidden_states.
-                                                         shape[0]]
+        restore_idx = self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]]
         return torch.index_select(
             hidden_states,
             0,
@@ -369,73 +349,61 @@ class PCPManager:
         num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32)
         for i, req_id in enumerate(input_batch.req_ids):
             num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
-        self.query_lens_pcp_full.cpu[:self.num_reqs] = torch.from_numpy(
-            num_scheduled_tokens_pcp_full)
-        req_indices_pcp_full = np.repeat(arange_np[:self.num_reqs],
-                                         num_scheduled_tokens_pcp_full)
+        self.query_lens_pcp_full.cpu[: self.num_reqs] = torch.from_numpy(num_scheduled_tokens_pcp_full)
+        req_indices_pcp_full = np.repeat(arange_np[: self.num_reqs], num_scheduled_tokens_pcp_full)
         cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
         self.query_start_loc_pcp_full.np[0] = 0
-        self.query_start_loc_pcp_full.np[1:self.num_reqs +
-                                         1] = cu_num_tokens_pcp_full
-        self.query_start_loc_pcp_full.np[self.num_reqs + 1:].fill(-1)
+        self.query_start_loc_pcp_full.np[1 : self.num_reqs + 1] = cu_num_tokens_pcp_full
+        self.query_start_loc_pcp_full.np[self.num_reqs + 1 :].fill(-1)
         cumsums_offsets_pcp_full = np.repeat(
-            cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
-            num_scheduled_tokens_pcp_full)
+            cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, num_scheduled_tokens_pcp_full
+        )
         arange_pcp_full = arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
-        positions_pcp_full_np = self.positions_pcp_full_np[:
-                                                           total_num_scheduled_tokens_pcp_full]
-        np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
-               arange_pcp_full,
-               out=positions_pcp_full_np)
-        token_indices_pcp_full = (
-            positions_pcp_full_np +
-            req_indices_pcp_full * input_batch.token_ids_cpu.shape[1])
-        torch.index_select(input_batch.token_ids_cpu_tensor.flatten(),
-                           0,
-                           torch.from_numpy(token_indices_pcp_full),
-                           out=self.input_ids_pcp_full.
-                           cpu[:total_num_scheduled_tokens_pcp_full])
+        positions_pcp_full_np = self.positions_pcp_full_np[:total_num_scheduled_tokens_pcp_full]
+        np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full], arange_pcp_full, out=positions_pcp_full_np)
+        token_indices_pcp_full = positions_pcp_full_np + req_indices_pcp_full * input_batch.token_ids_cpu.shape[1]
+        torch.index_select(
+            input_batch.token_ids_cpu_tensor.flatten(),
+            0,
+            torch.from_numpy(token_indices_pcp_full),
+            out=self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full],
+        )
 
         if self.use_async_scheduling:
-            self._update_input_ids_pcp_full_ids(input_batch, draft_token_ids,
-                                                scheduler_output,
-                                                total_num_scheduled_tokens,
-                                                cu_num_tokens_pcp_full,
-                                                num_spec_tokens)
+            self._update_input_ids_pcp_full_ids(
+                input_batch,
+                draft_token_ids,
+                scheduler_output,
+                total_num_scheduled_tokens,
+                cu_num_tokens_pcp_full,
+                num_spec_tokens,
+            )
         self.query_lens_pcp_full.copy_to_gpu()
         self.query_start_loc_pcp_full.copy_to_gpu()
-        self.input_ids_pcp_full.copy_to_gpu(
-            total_num_scheduled_tokens_pcp_full)
+        self.input_ids_pcp_full.copy_to_gpu(total_num_scheduled_tokens_pcp_full)
         self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full

         # For mtpx, pre-allocate the mtp slot_mapping here.
         if self.decode_threshold > 2 and not with_prefill:
             num_tokens_ori = sum(list(num_scheduled_tokens.values()))
-            num_tokens_mtp = \
-                num_tokens_ori + self.num_reqs * (self.decode_threshold - 2)
+            num_tokens_mtp = num_tokens_ori + self.num_reqs * (self.decode_threshold - 2)
             num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size
-            req_indices_split = np.array_split(req_indices,
-                                               cu_num_tokens)[:self.num_reqs]
-            positions_split = np.array_split(positions_np,
-                                             cu_num_tokens)[:self.num_reqs]
+            req_indices_split = np.array_split(req_indices, cu_num_tokens)[: self.num_reqs]
+            positions_split = np.array_split(positions_np, cu_num_tokens)[: self.num_reqs]
             for req_idx in range(self.num_reqs):
                 ori_req_indice = req_indices_split[req_idx]
                 ori_position = positions_split[req_idx]
                 req_indices_split[req_idx] = np.append(
-                    ori_req_indice,
-                    np.repeat(ori_req_indice[-1], self.decode_threshold - 2))
+                    ori_req_indice, np.repeat(ori_req_indice[-1], self.decode_threshold - 2)
+                )
                 positions_split[req_idx] = np.append(
-                    ori_position,
-                    np.arange(ori_position[-1] + 1,
-                              ori_position[-1] + self.decode_threshold - 1))
+                    ori_position, np.arange(ori_position[-1] + 1, ori_position[-1] + self.decode_threshold - 1)
+                )
             req_indices_mtp = np.concatenate(req_indices_split)
             positions_mtp = np.concatenate(positions_split)
-            input_batch.block_table.compute_slot_mapping(
-                req_indices_mtp, positions_mtp)
-            mtp_slot_ori = input_batch.block_table.block_tables[
-                0].slot_mapping.cpu[:num_tokens_mtp]
+            input_batch.block_table.compute_slot_mapping(req_indices_mtp, positions_mtp)
+            mtp_slot_ori = input_batch.block_table.block_tables[0].slot_mapping.cpu[:num_tokens_mtp]
             unpad_mask = np.repeat(False, num_tokens_mtp_pad)
-            unpad_mask[::self.pcp_world_size] = True
-            mtp_slot_pad = \
-                torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
+            unpad_mask[:: self.pcp_world_size] = True
+            mtp_slot_pad = torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
             mtp_slot_pad[unpad_mask] = mtp_slot_ori
             self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
@@ -454,8 +422,7 @@ class PCPManager:
         from the previous engine iteration, in which case those tokens on the
         GPU need to be copied into the corresponding slots in input_ids."""
-        if (input_batch.prev_sampled_token_ids is None
-                or input_batch.prev_req_id_to_index is None):
+        if input_batch.prev_sampled_token_ids is None or input_batch.prev_req_id_to_index is None:
             return
 
         # Async scheduling case, where some decode requests from the previous
@@ -481,9 +448,7 @@ class PCPManager:
                 # sample_flattened_indices = [0, 2, 5]
                 # spec_flattened_indices = [1, 3, 4, 6, 7]
                 sample_flattened_indices.append(flattened_index - draft_len)
-                spec_flattened_indices.extend(
-                    range(flattened_index - draft_len + 1,
-                          flattened_index + 1))
+                spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1))
                 start = prev_index * num_spec_tokens
                 # prev_draft_token_indices is used to find which draft_tokens_id
                 # should be copied to input_ids.
                 # For example:
                 # flattened draft_tokens_id [1, 2, 3, 4, 5, 6]
                 # draft_len of each request [1, 2, 1]
                 # then prev_draft_token_indices is [0, 2, 3, 4]
-                prev_draft_token_indices.extend(range(start,
-                                                      start + draft_len))
+                prev_draft_token_indices.extend(range(start, start + draft_len))
 
         num_commmon_tokens = len(sample_flattened_indices)
         if num_commmon_tokens == 0:
@@ -500,15 +464,12 @@ class PCPManager:
             # So input_ids.cpu will have all the input ids.
             return
         # Upload the index tensors asynchronously so the scatter can be non-blocking.
-        sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices,
-                                                   dtype=torch.int64)
-        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
-                                                      dtype=torch.int64)
+        sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices, dtype=torch.int64)
+        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices, dtype=torch.int64)
         self.input_ids_pcp_full.cpu.scatter_(
             dim=0,
             index=sampled_tokens_index_tensor,
-            src=input_batch.prev_sampled_token_ids[
-                prev_common_req_indices_tensor, 0].cpu(),
+            src=input_batch.prev_sampled_token_ids[prev_common_req_indices_tensor, 0].cpu(),
         )
 
         # Scatter the draft tokens after the sampled tokens are scattered.
@@ -516,10 +477,8 @@ class PCPManager:
             return
 
         assert isinstance(draft_token_ids, torch.Tensor)
-        draft_tokens_index_tensor = torch.tensor(spec_flattened_indices,
-                                                 dtype=torch.int64)
-        prev_draft_token_indices_tensor = torch.tensor(
-            prev_draft_token_indices, dtype=torch.int64)
+        draft_tokens_index_tensor = torch.tensor(spec_flattened_indices, dtype=torch.int64)
+        prev_draft_token_indices_tensor = torch.tensor(prev_draft_token_indices, dtype=torch.int64)
         # input_ids dtype is torch.int32, so convert draft_token_ids
         # to torch.int32 here.
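The flattened-index bookkeeping above is easiest to check in isolation. The following is a minimal sketch, not part of the patch: `flatten_token_indices` is a hypothetical stand-alone helper, the draft lengths are made up, and it assumes each request occupies one sampled-token slot followed by its draft-token slots in the flattened input_ids buffer. It reproduces the index lists quoted in the comments.

```python
# Hypothetical stand-alone sketch of the index bookkeeping used when async
# scheduling scatters previously sampled/draft tokens into input_ids.
def flatten_token_indices(draft_lens: list[int], num_spec_tokens: int):
    sample_flattened_indices: list[int] = []
    spec_flattened_indices: list[int] = []
    prev_draft_token_indices: list[int] = []
    flattened_index = -1
    for prev_index, draft_len in enumerate(draft_lens):
        # Each request occupies (1 sampled + draft_len draft) slots;
        # flattened_index points at the request's last slot.
        flattened_index += 1 + draft_len
        sample_flattened_indices.append(flattened_index - draft_len)
        spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1))
        # Draft ids live in a [num_reqs, num_spec_tokens] buffer, so the
        # source indices are offset by prev_index * num_spec_tokens.
        start = prev_index * num_spec_tokens
        prev_draft_token_indices.extend(range(start, start + draft_len))
    return sample_flattened_indices, spec_flattened_indices, prev_draft_token_indices


# Draft lengths [1, 2, 2] reproduce the first comment's index lists:
#   ([0, 2, 5], [1, 3, 4, 6, 7], [0, 2, 3, 4, 5])
print(flatten_token_indices([1, 2, 2], num_spec_tokens=2))
# Draft lengths [1, 2, 1] reproduce the second comment's
# prev_draft_token_indices == [0, 2, 3, 4]:
print(flatten_token_indices([1, 2, 1], num_spec_tokens=2))
```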
@@ -528,8 +487,7 @@ class PCPManager:
         self.input_ids_pcp_full.cpu.scatter_(
             dim=0,
             index=draft_tokens_index_tensor,
-            src=draft_token_ids.flatten()
-            [prev_draft_token_indices_tensor].cpu(),
+            src=draft_token_ids.flatten()[prev_draft_token_indices_tensor].cpu(),
         )
 
     def _get_cp_local_seq_lens(
@@ -545,41 +503,32 @@ class PCPManager:
         num_requests = seq_lens.size(0)
         total_world_size = pcp_world_size * dcp_world_size
         seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
-        rank_offsets = (torch.arange(total_world_size,
-                                     dtype=torch.int32).unsqueeze(0).repeat(
-                                         num_requests, 1))
-        base = (seq_lens_tiled // cp_kv_cache_interleave_size //
-                total_world_size * cp_kv_cache_interleave_size)
+        rank_offsets = torch.arange(total_world_size, dtype=torch.int32).unsqueeze(0).repeat(num_requests, 1)
+        base = seq_lens_tiled // cp_kv_cache_interleave_size // total_world_size * cp_kv_cache_interleave_size
         remainder = seq_lens_tiled - base * total_world_size
         remainder = torch.clip(
             remainder - rank_offsets * cp_kv_cache_interleave_size,
             0,
             cp_kv_cache_interleave_size,
         )
-        dcp_local_seq_lens = (base + remainder).reshape(
-            [-1, pcp_world_size, dcp_world_size])
+        dcp_local_seq_lens = (base + remainder).reshape([-1, pcp_world_size, dcp_world_size])
         return dcp_local_seq_lens
 
-    def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
-                              input_batch, num_scheduled_tokens):
-        from vllm_ascend.attention.utils import \
-            AscendPrefillContextParallelMetadata
+    def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
+        from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata
+
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_world_size * self.dcp_world_size > 1:
-            decode_context_lens = input_batch.num_computed_tokens_cpu[:
-                self.num_decode_reqs] + num_scheduled_tokens[:
-                self.num_decode_reqs]
-            prefill_context_lens = input_batch.num_computed_tokens_cpu[
-                self.num_decode_reqs:self.num_reqs]
-            context_lens = np.concatenate(
-                [decode_context_lens, prefill_context_lens])
+            decode_context_lens = (
+                input_batch.num_computed_tokens_cpu[: self.num_decode_reqs]
+                + num_scheduled_tokens[: self.num_decode_reqs]
+            )
+            prefill_context_lens = input_batch.num_computed_tokens_cpu[self.num_decode_reqs : self.num_reqs]
+            context_lens = np.concatenate([decode_context_lens, prefill_context_lens])
             num_computed_tokens_of_pcp_dcp = torch.zeros(
-                [
-                    self.num_reqs * self.decode_threshold, self.pcp_world_size,
-                    self.dcp_world_size
-                ],
+                [self.num_reqs * self.decode_threshold, self.pcp_world_size, self.dcp_world_size],
                 dtype=torch.int32,
             )
             # For pcp + spec decode, we flatten seq_lens
@@ -587,41 +536,37 @@ class PCPManager:
             # Same as block_table, we flatten decode seq_lens to query_lens,
             # and keep prefill seq_lens unchanged.
             for decode_idx in range(self.decode_threshold):
-                num_computed_tokens_of_pcp_dcp[
-                    self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
+                num_computed_tokens_of_pcp_dcp[self.decode_threshold - 1 - decode_idx :: self.decode_threshold] = (
                     self._get_cp_local_seq_lens(
                         torch.tensor(context_lens) - decode_idx,
                         self.pcp_world_size,
                         self.dcp_world_size,
                         self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
                     )
+                )
             if self.decode_threshold > 1:
                 num_computed_tokens_of_pcp_dcp_list = []
                 if self.num_decode_reqs:
-                    num_decodes_flatten = \
-                        query_lens[:self.num_decode_reqs].sum().item()
-                    if query_lens[:self.num_decode_reqs].min().item(
-                    ) == self.decode_threshold:
+                    num_decodes_flatten = query_lens[: self.num_decode_reqs].sum().item()
+                    if query_lens[: self.num_decode_reqs].min().item() == self.decode_threshold:
                         decode_flatten_idx = list(range(num_decodes_flatten))
                     else:
                         decode_flatten_idx = []
                         for req_id in range(self.num_decode_reqs):
                             offset = (req_id + 1) * self.decode_threshold
-                            decode_flatten_idx += \
-                                list(range(offset - query_lens[req_id], offset))
-                    num_computed_tokens_of_pcp_dcp_list.append(
-                        num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
+                            decode_flatten_idx += list(range(offset - query_lens[req_id], offset))
+                    num_computed_tokens_of_pcp_dcp_list.append(num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
                 if self.num_prefill_reqs:
                     num_computed_tokens_of_pcp_dcp_list.append(
                         num_computed_tokens_of_pcp_dcp[
-                            (self.num_decode_reqs + 1) * self.decode_threshold -
-                            1::self.decode_threshold])
-                num_computed_tokens_of_pcp_dcp = torch.cat(
-                    num_computed_tokens_of_pcp_dcp_list, dim=0)
+                            (self.num_decode_reqs + 1) * self.decode_threshold - 1 :: self.decode_threshold
+                        ]
+                    )
+                num_computed_tokens_of_pcp_dcp = torch.cat(num_computed_tokens_of_pcp_dcp_list, dim=0)
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
-                num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
-                numpy())
+                num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
+            )
         if self.pcp_world_size > 1:
             q_head_idx, q_tail_idx = [], []
             kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -639,109 +584,102 @@ class PCPManager:
                     continue
                 chunk_len = seq_len // 2
                 chunk_seqlens.append(chunk_len)
-                q_head_idx.extend(
-                    list(range(q_req_offset, q_req_offset + chunk_len)))
+                q_head_idx.extend(list(range(q_req_offset, q_req_offset + chunk_len)))
                 kv_with_q_head_nomask_idx.extend(
-                    list(
-                        range(kv_req_offset, kv_req_offset +
-                              chunk_len * q_head_chunk_id)))
+                    list(range(kv_req_offset, kv_req_offset + chunk_len * q_head_chunk_id))
+                )
                 kv_with_q_head_mask_idx.extend(
                     list(
                         range(
                             kv_req_offset + chunk_len * q_head_chunk_id,
-                            kv_req_offset + chunk_len *
-                            (q_head_chunk_id + 1))))
-                kv_with_q_head_nomask_seqlens.append(chunk_len *
-                                                     q_head_chunk_id)
+                            kv_req_offset + chunk_len * (q_head_chunk_id + 1),
+                        )
+                    )
+                )
+                kv_with_q_head_nomask_seqlens.append(chunk_len * q_head_chunk_id)
                 split_with_q_head_nomask_idx_reqs.append(
-                    list(
-                        range(kv_req_offset, kv_req_offset +
-                              chunk_len * q_head_chunk_id)))
-                q_tail_idx.extend(
-                    list(
-                        range(q_req_offset + chunk_len,
-                              q_req_offset + chunk_len * 2)))
+                    list(range(kv_req_offset, kv_req_offset + chunk_len * q_head_chunk_id))
+                )
+                q_tail_idx.extend(list(range(q_req_offset + chunk_len, q_req_offset + chunk_len * 2)))
                 kv_with_q_tail_nomask_idx.extend(
-                    list(
-                        range(kv_req_offset, kv_req_offset +
-                              chunk_len * q_tail_chunk_id)))
+                    list(range(kv_req_offset, kv_req_offset + chunk_len * q_tail_chunk_id))
+                )
                 kv_with_q_tail_mask_idx.extend(
                     list(
                         range(
                             kv_req_offset + chunk_len * q_tail_chunk_id,
-                            kv_req_offset + chunk_len *
-                            (q_tail_chunk_id + 1))))
-                kv_with_q_tail_nomask_seqlens.append(chunk_len *
-                                                     q_tail_chunk_id)
+                            kv_req_offset + chunk_len * (q_tail_chunk_id + 1),
+                        )
+                    )
+                )
+                kv_with_q_tail_nomask_seqlens.append(chunk_len * q_tail_chunk_id)
                 split_kv_with_q_tail_nomask_idx_reqs.append(
-                    list(
-                        range(kv_req_offset, kv_req_offset +
-                              chunk_len * q_tail_chunk_id)))
+                    list(range(kv_req_offset, kv_req_offset + chunk_len * q_tail_chunk_id))
+                )
                 q_req_offset += seq_len
                 kv_req_offset += seq_len * self.pcp_world_size
 
-            q_head_idx_tensor = self._list_to_tensor(
-                q_head_idx, self.device)
-            q_tail_idx_tensor = self._list_to_tensor(
-                q_tail_idx, self.device)
+            q_head_idx_tensor = self._list_to_tensor(q_head_idx, self.device)
+            q_tail_idx_tensor = self._list_to_tensor(q_tail_idx, self.device)
             self.q_head_idx_tensor = q_head_idx_tensor
             self.q_tail_idx_tensor = q_tail_idx_tensor
             q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
-            q_full_idx = q_full_idx.to(torch.float32).argsort().to(
-                torch.int32)
+            q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32)
             self.q_full_idx = q_full_idx
             self.kv_idx_names = {
-                'kv_with_q_head_nomask_idx_tensor':
-                kv_with_q_head_nomask_idx,
-                'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
-                'kv_with_q_tail_nomask_idx_tensor':
-                kv_with_q_tail_nomask_idx,
-                'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
+                "kv_with_q_head_nomask_idx_tensor": kv_with_q_head_nomask_idx,
+                "kv_with_q_head_mask_idx_tensor": kv_with_q_head_mask_idx,
+                "kv_with_q_tail_nomask_idx_tensor": kv_with_q_tail_nomask_idx,
+                "kv_with_q_tail_mask_idx_tensor": kv_with_q_tail_mask_idx,
             }
             for key, value in self.kv_idx_names.items():
                 tensor_npu = self._list_to_tensor(value, self.device)
                 self.kv_idx_names[key] = tensor_npu
-            attn_mask_seqlens = torch.tensor(
-                [chunk_seqlens, chunk_seqlens],
-                dtype=torch.int32)
+            attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], dtype=torch.int32)
             head_attn_nomask_seqlens = torch.tensor(
-                [chunk_seqlens, kv_with_q_head_nomask_seqlens],
-                dtype=torch.int32)
+                [chunk_seqlens, kv_with_q_head_nomask_seqlens], dtype=torch.int32
+            )
             tail_attn_nomask_seqlens = torch.tensor(
-                [chunk_seqlens, kv_with_q_tail_nomask_seqlens],
-                dtype=torch.int32)
+                [chunk_seqlens, kv_with_q_tail_nomask_seqlens], dtype=torch.int32
+            )
             if self.vllm_config.model_config.use_mla:
-                split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = self._split_nomask_idx_tensor_list(
+                (
+                    split_q_head_nomask_idx_tensor_list,
+                    split_q_tail_nomask_idx_tensor_list,
+                    head_attn_nomask_seqlens_list,
+                    tail_attn_nomask_seqlens_list,
+                ) = self._split_nomask_idx_tensor_list(
                     split_with_q_head_nomask_idx_reqs,
                     split_kv_with_q_tail_nomask_idx_reqs,
-                    head_attn_nomask_seqlens, chunk_seqlens)
+                    head_attn_nomask_seqlens,
+                    chunk_seqlens,
+                )
             self.extra_long_seq_kwargs = {
-                'attn_mask_seqlens': attn_mask_seqlens,
-                'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
-                'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens
+                "attn_mask_seqlens": attn_mask_seqlens,
+                "head_attn_nomask_seqlens": head_attn_nomask_seqlens,
+                "tail_attn_nomask_seqlens": tail_attn_nomask_seqlens,
             }
-            long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
-                                                                                             num_actual_tokens_pcp_padded]
+            long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
+                :num_actual_tokens_pcp_padded
+            ]
             long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
             long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
             long_seq_metadata.q_full_idx = self.q_full_idx
             long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
-                'kv_with_q_head_nomask_idx_tensor']
-            long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[
-                'kv_with_q_head_mask_idx_tensor']
+                "kv_with_q_head_nomask_idx_tensor"
+            ]
+            long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names["kv_with_q_head_mask_idx_tensor"]
             long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
-                'kv_with_q_tail_nomask_idx_tensor']
-            long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[
-                'kv_with_q_tail_mask_idx_tensor']
-            long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[
-                'attn_mask_seqlens']
-            long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[
-                'head_attn_nomask_seqlens']
-            long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
-                'tail_attn_nomask_seqlens']
+                "kv_with_q_tail_nomask_idx_tensor"
+            ]
+            long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names["kv_with_q_tail_mask_idx_tensor"]
+            long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs["attn_mask_seqlens"]
+            long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs["head_attn_nomask_seqlens"]
+            long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs["tail_attn_nomask_seqlens"]
             if self.vllm_config.model_config.use_mla:
                 long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
                 long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
@@ -755,46 +693,53 @@ class PCPManager:
         tensor_npu.copy_(torch.tensor(lst, dtype=dtype), non_blocking=True)
         return tensor_npu
 
-    def _split_nomask_idx_tensor_list(self, split_with_q_head_nomask_idx_reqs,
-                                      split_kv_with_q_tail_nomask_idx_reqs,
-                                      head_attn_nomask_seqlens, chunk_seqlens):
-        split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list= [], []
+    def _split_nomask_idx_tensor_list(
+        self,
+        split_with_q_head_nomask_idx_reqs,
+        split_kv_with_q_tail_nomask_idx_reqs,
+        head_attn_nomask_seqlens,
+        chunk_seqlens,
+    ):
+        split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list = [], []
         head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = [], []
         if split_with_q_head_nomask_idx_reqs:
-            #In long-sequence scenarios, the computational cost and latency
-            #of the _npu_ring_mla operator are not proportional, so we split
-            #long sequences into shorter ones to improve performance.
+            # In long-sequence scenarios, the computational cost and latency
+            # of the _npu_ring_mla operator are not proportional, so we split
+            # long sequences into shorter ones to improve performance.
            split_size = 16 * 1024
             if self.pcp_world_rank == 0:
-                split_q_head_nomask_idx_list = [
-                    self.kv_idx_names['kv_with_q_head_nomask_idx_tensor']
-                ]
+                split_q_head_nomask_idx_list = [self.kv_idx_names["kv_with_q_head_nomask_idx_tensor"]]
             else:
                 split_q_head_nomask_idx_list, split_q_head_nomask_lens_list = self._split_multi_batch_kv_idx(
-                    split_with_q_head_nomask_idx_reqs, split_size)
+                    split_with_q_head_nomask_idx_reqs, split_size
+                )
             split_q_tail_nomask_idx_list, split_q_tail_nomask_lens_list = self._split_multi_batch_kv_idx(
-                split_kv_with_q_tail_nomask_idx_reqs, split_size)
+                split_kv_with_q_tail_nomask_idx_reqs, split_size
+            )
             for q_head_nomask_idx in split_q_head_nomask_idx_list:
-                split_q_head_nomask_idx_tensor_list.append(
-                    self._list_to_tensor(q_head_nomask_idx, self.device))
+                split_q_head_nomask_idx_tensor_list.append(self._list_to_tensor(q_head_nomask_idx, self.device))
             for q_tail_nomask_idx in split_q_tail_nomask_idx_list:
-                split_q_tail_nomask_idx_tensor_list.append(
-                    self._list_to_tensor(q_tail_nomask_idx, self.device))
+                split_q_tail_nomask_idx_tensor_list.append(self._list_to_tensor(q_tail_nomask_idx, self.device))
             if self.pcp_world_rank == 0:
                 head_attn_nomask_seqlens_list = [head_attn_nomask_seqlens]
             else:
                 for q_head_nomask_lens in split_q_head_nomask_lens_list:
                     head_attn_nomask_seqlens_list.append(
-                        torch.tensor([chunk_seqlens, q_head_nomask_lens],
-                                     dtype=torch.int32))
+                        torch.tensor([chunk_seqlens, q_head_nomask_lens], dtype=torch.int32)
+                    )
             for q_tail_nomask_lens in split_q_tail_nomask_lens_list:
                 tail_attn_nomask_seqlens_list.append(
-                    torch.tensor([chunk_seqlens, q_tail_nomask_lens],
-                                 dtype=torch.int32))
-        return split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list
+                    torch.tensor([chunk_seqlens, q_tail_nomask_lens], dtype=torch.int32)
+                )
+        return (
+            split_q_head_nomask_idx_tensor_list,
+            split_q_tail_nomask_idx_tensor_list,
+            head_attn_nomask_seqlens_list,
+            tail_attn_nomask_seqlens_list,
+        )
 
     def _split_multi_batch_kv_idx(
         self,
@@ -813,7 +758,7 @@ class PCPManager:
             current_batch_len = []
             for t in range(time):
                 start = t * split_size
-                current_segment = single_batch[start:start + split_size]
+                current_segment = single_batch[start : start + split_size]
                 current_batch_split.append(current_segment)
                 current_batch_len.append(len(current_segment))
@@ -829,8 +774,9 @@ class PCPManager:
         def reshape_kv_len_to_time_first(split_kv_len_2d):
             if not split_kv_len_2d or not split_kv_len_2d[0]:
                 return []
-            return [[batch_len[time_idx] for batch_len in split_kv_len_2d]
-                    for time_idx in range(len(split_kv_len_2d[0]))]
+            return [
+                [batch_len[time_idx] for batch_len in split_kv_len_2d] for time_idx in range(len(split_kv_len_2d[0]))
+            ]
 
         merged_split_kv_len_2d = reshape_kv_len_to_time_first(split_kv_len_2d)
         return merged_split_kv_idx_3d, merged_split_kv_len_2d
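For reference, the split-and-regroup scheme behind `_split_multi_batch_kv_idx` and `reshape_kv_len_to_time_first` can be exercised outside the runner. The sketch below is a simplified, stand-alone version under stated assumptions: the helper name and the toy `split_size` are illustrative only, and it applies the same time-first reshape to both index segments and segment lengths, which presumes every request produces the same number of segments (the `len(rows[0])` indexing in the patched code makes the same assumption).

```python
# Simplified sketch: cut each request's KV index list into split_size
# segments, then regroup time-first so that segment t of every request
# can be consumed together.
import math


def split_multi_batch_kv_idx(batches: list[list[int]], split_size: int):
    split_kv_idx_2d = []  # [batch][time] -> index segment
    split_kv_len_2d = []  # [batch][time] -> segment length
    for single_batch in batches:
        time = max(1, math.ceil(len(single_batch) / split_size))
        current_batch_split, current_batch_len = [], []
        for t in range(time):
            start = t * split_size
            current_segment = single_batch[start : start + split_size]
            current_batch_split.append(current_segment)
            current_batch_len.append(len(current_segment))
        split_kv_idx_2d.append(current_batch_split)
        split_kv_len_2d.append(current_batch_len)

    def reshape_time_first(rows):
        # [batch][time] -> [time][batch]; assumes equal segment counts.
        if not rows or not rows[0]:
            return []
        return [[row[t] for row in rows] for t in range(len(rows[0]))]

    return reshape_time_first(split_kv_idx_2d), reshape_time_first(split_kv_len_2d)


# Two requests, split_size=4: request 0 has 6 indices, request 1 has 5.
idx, lens = split_multi_batch_kv_idx([list(range(6)), list(range(10, 15))], 4)
# lens == [[4, 4], [2, 1]]: time step 0 holds the first four indices of
# each request, time step 1 the remainders.
```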