xc-llm-ascend/vllm_ascend/worker/v2/aclgraph_utils.py


# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/aclgraph_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from contextlib import contextmanager
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import vllm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import logger
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.compilation.acl_graph import set_graph_params, update_full_graph_params
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper


class AclGraphManager(CudaGraphManager):
    """ACL Graph Manager for Ascend NPUs."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        use_aux_hidden_state_outputs: bool,
        device: torch.device,
        model_runner: Any,  # NPUModelRunner; typed as Any to avoid a circular import
    ):
        # Store the model runner so that we can access its attributes when
        # calling the `run_fullgraph` method in CudaGraphManager. This way we
        # don't need to copy the `execute_model` method in the
        # `NPUModelRunner` class.
        self.model_runner = model_runner
        super().__init__(
            vllm_config,
            use_aux_hidden_state_outputs,
            device,
        )
        # vllm-ascend needs to update the graph params of the attention backend,
        # so the graph params must be set before capturing the full graph.
        if super().needs_capture():
            set_graph_params(self.cudagraph_sizes)

    def _capture_full_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Override _capture_full_graph because we need to set capturing=True in forward context."""
        # Wrap the model so that capturing=True is set before the model forward.
        model = ModelWithContext(model)
        return super()._capture_full_graph(
            num_tokens,
            num_reqs,
            model,
            input_ids,
            positions,
            inputs_embeds,
            num_tokens_across_dp,
            attn_metadata,
            slot_mappings,
            has_lora,
        )

    def capture_graph(
        self,
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        uniform_decode: bool = False,
    ) -> None:
        with torch_cuda_wrapper(), prepare_capture_inputs_wrapper():
            super().capture_graph(
                num_tokens,
                capture_cg_mode,
                model,
                model_state,
                input_buffers,
                block_tables,
                attn_groups,
                kv_cache_config,
                has_lora,
                uniform_decode,
            )

    def run_fullgraph(self, num_tokens: int) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Override run_fullgraph to update the full graph params."""
        logger.info_once(f"run_fullgraph with num_tokens={num_tokens}")
        ret = super().run_fullgraph(num_tokens)
        assert self.model_runner.cudagraph_and_dp_padding is not None
        positions = self.model_runner.input_buffers.positions[:num_tokens]
        _num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
            self.model_runner.cudagraph_and_dp_padding
        )
        cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
        with set_forward_context(
            self.model_runner.input_batch.attn_metadata,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            num_tokens_across_dp=num_tokens_across_dp,
            batch_descriptor=None,  # Full graph mode doesn't need batch_descriptor
            slot_mapping=self.model_runner.input_batch.slot_mappings,
        ):
            forward_context = get_forward_context()
            update_full_graph_params(
                # FIXME(Ronald1995): support hybrid attn backend
                list(self.model_runner.attn_backends.values())[0],
                self.model_runner.update_stream,
                forward_context,
                num_tokens,
                self.vllm_config,
                self.model_runner.speculative_config,
                positions.shape[0],
            )
        return ret

    def is_uniform_decode(
        self,
        num_reqs: int,
        num_tokens: int,
        max_query_len: int,
    ):
        return (max_query_len == self.uniform_decode_query_len) and (num_tokens == max_query_len * num_reqs)


@contextmanager
def prepare_capture_inputs_wrapper():
    """Context manager to override input preparation for NPU graph capture."""
    # TODO(Ronald1995): make prepare_inputs_to_capture a static method
    # in CudaGraphManager.
    ori = vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture
    try:
        vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = prepare_inputs_to_capture
        yield
    finally:
        # Always restore the original function, even if capture raises.
        vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = ori


def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    max_model_len: int,
    kv_cache_config: KVCacheConfig,
    uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    if uniform_decode_query_len > 0:
        num_tokens_per_req = uniform_decode_query_len
    else:
        num_tokens_per_req = num_tokens // num_reqs

    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
    query_start_loc_np[-1] = num_tokens
    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]

    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
    # rather than max_model_len.
    input_buffers.seq_lens[:num_reqs] = num_tokens
    input_buffers.seq_lens[num_reqs:] = 0
    input_buffers.seq_lens_cpu[:num_reqs] = num_tokens
    input_buffers.seq_lens_cpu[num_reqs:] = 0
    input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
    input_buffers.dcp_local_seq_lens[num_reqs:] = 0

    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
    slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, kv_cache_config)

    attn_metadata = build_attn_metadata(
        attn_groups=attn_groups,
        num_reqs=num_reqs,
        num_tokens=num_tokens,
        query_start_loc_gpu=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        max_query_len=num_tokens_per_req,
        seq_lens=input_buffers.seq_lens,
        max_seq_len=max_model_len,
        block_tables=input_block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
        seq_lens_np=input_buffers.seq_lens_np,
    )
    return attn_metadata, slot_mappings_by_layer


class ModelWithContext(nn.Module):
    """Wrapper model that injects the forward context.

    This lets us reuse vllm's CudaGraphManager._capture_full_graph.
    """

    def __init__(self, original_model):
        super().__init__()
        self.original_model = original_model

    def forward(self, *args, **kwargs):
        # In the warmup phase, capturing=False by default.
        # When capturing, we need to set capturing=True in the forward context.
        _EXTRA_CTX.capturing = True
        return self.original_model(*args, **kwargs)