diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 19b021e..ed0761b 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -115,24 +115,26 @@ jobs: - name: Install vllm-project/vllm-ascend run: | pip install -r requirements-dev.txt - pip install -e . - - - name: Run vllm-project/vllm-ascend test on V0 engine - env: - VLLM_USE_V1: 0 - run: | - if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then - pytest -sv tests/singlecard/test_offline_inference.py - pytest -sv tests/ops - else - pytest -sv tests/multicard/test_offline_inference_distributed.py - pytest -sv tests/ops - fi + pip install -v --no-build-isolation -e . - name: Run vllm-project/vllm-ascend test for V1 Engine env: VLLM_USE_V1: 1 VLLM_WORKER_MULTIPROC_METHOD: spawn + run: | + if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then + pytest -sv tests/singlecard/test_offline_inference.py + pytest -sv tests/ops + pytest -sv tests/compile + else + pytest -sv tests/multicard/test_offline_inference_distributed.py + pytest -sv tests/ops + pytest -sv tests/compile + fi + + - name: Run vllm-project/vllm-ascend test on V0 engine + env: + VLLM_USE_V1: 0 run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then pytest -sv tests/singlecard/test_offline_inference.py diff --git a/csrc/ops.h b/csrc/ops.h index 4296796..aaac630 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -21,6 +21,7 @@ #include #include "kernels/types.h" +#include "torch_npu/csrc/aten/common/from_blob.h" namespace vllm_ascend { extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst, @@ -29,4 +30,20 @@ namespace vllm_ascend { const int64_t dstKeyStride, const int numHeads, const int numKvHeads, const int headSize, const int64_t numTokens, const uint32_t loopCnt, uint32_t aivNum); -} \ No newline at end of file + + torch::Tensor weak_ref_tensor(torch::Tensor& tensor) { + if (!tensor.is_privateuseone()) { + throw std::runtime_error("Tensor must be on NPU device"); + } + // Get the raw data pointer + void* data_ptr = tensor.data_ptr(); + // Get tensor sizes and strides + std::vector sizes = tensor.sizes().vec(); + std::vector strides = tensor.strides().vec(); + // Get tensor options (dtype, device) + auto options = tensor.options(); + // Create a new tensor from the raw data pointer + auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options); + return new_tensor; + } +} diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index b874a43..94e1fd6 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -103,6 +103,8 @@ std::tuple rotary_embedding(at::Tensor &positions, at::T TORCH_LIBRARY_EXPAND(_C, ops) { // vLLM-Ascend custom ops + ops.def("weak_ref_tensor(Tensor input) -> Tensor"); + ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor); // Rotary embedding // Apply GPT-NeoX style rotary embedding to query and key. diff --git a/pyproject.toml b/pyproject.toml index ac81c21..ee0a440 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,8 @@ requires = [ "scipy", "setuptools>=64", "setuptools-scm>=8", - "torch_npu", - "torch >= 2.5.1", + "torch_npu==2.5.1rc1", + "torch>=2.5.1", "torchvision<0.21.0", ] build-backend = "setuptools.build_meta" diff --git a/requirements-dev.txt b/requirements-dev.txt index 0113f76..9bd9239 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,5 @@ -r requirements-lint.txt +-r requirements.txt modelscope pytest >= 6.0 pytest-asyncio diff --git a/requirements.txt b/requirements.txt index fec71cb..14d038b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,11 +3,12 @@ cmake>=3.26 decorator numpy<2.0.0 packaging +pip pybind11 pyyaml scipy setuptools>=64 setuptools-scm>=8 -torch_npu -torch >= 2.5.1 +torch>=2.5.1 torchvision<0.21.0 +wheel diff --git a/tests/compile/__init__.py b/tests/compile/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/compile/test_simple.py b/tests/compile/test_simple.py new file mode 100644 index 0000000..cb54422 --- /dev/null +++ b/tests/compile/test_simple.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Test the piecewise compilation with a simple model so that we +can exactly calculate the expected output and side effects. +""" + +import pytest +import torch +from torch import nn +from torch.library import Library +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) +from vllm.utils import direct_register_custom_op + +global_counter = 0 + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + global global_counter + global_counter += 1 + print(f"{global_counter=}") + out.copy_(q) + out[0] += 1 + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + dispatch_key="PrivateUse1", + target_lib=silly_lib, +) + + +@support_torch_compile +class SillyModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overall effect: + x += 1 + x[0] += 2 + global_counter += 2 + """ + x = x + 1 + x = x + 2 + out = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, out) + x = out + x = x - 2 + x = x - 1 + out = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, out) + x = out + x = x + 1 + return x + + +@pytest.mark.skipif(True, reason="requires unreleased components") +def test_simple_piecewise_compile(): + + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_inductor=False, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_copy_inputs=True, + cudagraph_capture_sizes=[1, 2], + )) + vllm_config.compilation_config.pass_config.enable_fusion = False + with set_current_vllm_config(vllm_config): + model = SillyModel(vllm_config=vllm_config, prefix="") + + inputs = torch.randn(100).npu() + + with compilation_counter.expect( + num_graphs_seen=1, # one graph for the model + num_piecewise_graphs_seen=5, # 2 * num_layers + 1 + num_piecewise_capturable_graphs_seen=3, # 1 + num_layers + num_backend_compilations=3, # num_piecewise_capturable_graphs_seen + num_cudagraph_caputured= + 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + + model(inputs) + + model(torch.randn(2).npu()) + model(torch.randn(1).npu()) + + input = torch.zeros(2).npu() + global global_counter + global_counter = 0 + output = model(input) + assert global_counter == 2 + assert torch.allclose(output.cpu(), torch.tensor([3.0, 1.0])) + + +if __name__ == "__main__": + test_simple_piecewise_compile() diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index dfc6675..a41996d 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -47,6 +47,7 @@ def test_models_distributed(model: str, dtype=dtype, tensor_parallel_size=4, distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/singlecard/test_offline_inference.py b/tests/singlecard/test_offline_inference.py index 9249c33..7ccd9cf 100644 --- a/tests/singlecard/test_offline_inference.py +++ b/tests/singlecard/test_offline_inference.py @@ -50,7 +50,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None: with VllmRunner(model, max_model_len=8192, dtype=dtype, - enforce_eager=False, + enforce_eager=True, gpu_memory_utilization=0.7) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 39878ce..03d0bcc 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -24,6 +24,8 @@ import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.utils import direct_register_custom_op from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch @@ -31,6 +33,7 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill class AscendAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True @staticmethod def get_name() -> str: @@ -198,6 +201,7 @@ class AscendAttentionBackendImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: AscendMetadata, output: Optional[torch.Tensor] = None, + trace_flag: bool = True, ) -> torch.Tensor: """Forward pass with Ascend attention. Args: @@ -215,98 +219,150 @@ class AscendAttentionBackendImpl(AttentionImpl): shape = [batch_size * seq_len, num_heads, head_size] """ num_tokens = query.shape[0] - output = torch.empty(num_tokens, - self.num_heads, - self.head_size, - dtype=query.dtype, - device=query.device) - - if attn_metadata is None: - # Profiling run. - return output.view(num_tokens, self.hidden_size) - assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - attn_type = self.attn_type - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "PallasAttentionBackendImpl") - # View q k v to BSH. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - # TODO: Remove this contiguous in the future. - value = value.contiguous() - - if kv_cache.numel() > 0: - if self.key_cache is None: - self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] - slots = attn_metadata.slot_mapping - torch_npu._npu_reshape_and_cache(key=key, - value=value, - key_cache=self.key_cache, - value_cache=self.value_cache, - slot_indices=slots) - - if hasattr(layer, 'quant_method'): - # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata - pass - # V0-Style scheduler situation. - elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly: - assert attn_metadata is not None - assert attn_metadata.attn_mask is not None - mask = attn_metadata.attn_mask - torch_npu._npu_flash_attention(query=query, - key=key, - value=value, - mask=mask, - seq_len=attn_metadata.seq_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - out=output) - elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - block_tables = attn_metadata.block_tables - torch_npu._npu_paged_attention(query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=block_tables, - context_lens=attn_metadata.seq_lens, - out=output) - # Normal V1 situation. + if output is None: + output = torch.empty(num_tokens, + self.num_heads, + self.head_size, + dtype=query.dtype, + device=query.device) + if trace_flag: + torch.ops.vllm.unified_ascend_attention_with_output( + query=query, + key=key, + value=value, + output=output, + layer_name=layer.layer_name) else: - # use chunked prefill for head size 192 scenario, like deepseek - # paged_attention_splitfuse maybe crash at such scenario - # TODO: vanilla path will be removed after the kernel support - # head_size 192 scenario - if self.head_size == 192: - cu_seqlen_q = [0] + attn_metadata.query_lens.tolist() - cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist() - cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu") - cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu") - cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0) - cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0) - max_seqlen_q = torch.max(attn_metadata.query_lens) - max_seqlen_k = torch.max(attn_metadata.seq_lens) - vanilla_chunked_prefill(output, query, self.key_cache, - self.value_cache, - attn_metadata.block_tables, - cu_seqlen_q, cu_seqlen_k, max_seqlen_q, - max_seqlen_k, self.scale, None, True) - else: - torch_npu._npu_paged_attention_splitfuse( + num_tokens = query.shape[0] + if attn_metadata is None: + return output.view(num_tokens, self.hidden_size) + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + attn_type = self.attn_type + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + # View q k v to BSH. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + # TODO: Remove this contiguous in the future. + value = value.contiguous() + + if kv_cache.numel() > 0: + if self.key_cache is None: + self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] + slots = attn_metadata.slot_mapping + torch_npu._npu_reshape_and_cache(key=key, + value=value, + key_cache=self.key_cache, + value_cache=self.value_cache, + slot_indices=slots) + + if hasattr(layer, 'quant_method'): + # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata + pass + # V0-Style scheduler situation. + elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly: + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + mask = attn_metadata.attn_mask + torch_npu._npu_flash_attention(query=query, + key=key, + value=value, + mask=mask, + seq_len=attn_metadata.seq_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + out=output) + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + block_tables = attn_metadata.block_tables + torch_npu._npu_paged_attention( query=query, key_cache=self.key_cache, value_cache=self.value_cache, - mask=attn_metadata.attn_mask, - block_table=attn_metadata.block_tables, - seq_len=attn_metadata.query_lens, - context_lens=attn_metadata.seq_lens, num_kv_heads=self.num_kv_heads, num_heads=self.num_heads, scale_value=self.scale, + block_table=block_tables, + context_lens=attn_metadata.seq_lens, out=output) + # Normal V1 situation. + else: + # use chunked prefill for head size 192 scenario, like deepseek + # paged_attention_splitfuse maybe crash at such scenario + # TODO: vanilla path will be removed after the kernel support + # head_size 192 scenario + if self.head_size == 192: + cu_seqlen_q = [0] + attn_metadata.query_lens.tolist() + cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist() + cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu") + cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu") + cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0) + cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0) + max_seqlen_q = torch.max(attn_metadata.query_lens) + max_seqlen_k = torch.max(attn_metadata.seq_lens) + vanilla_chunked_prefill(output, query, self.key_cache, + self.value_cache, + attn_metadata.block_tables, + cu_seqlen_q, cu_seqlen_k, + max_seqlen_q, max_seqlen_k, + self.scale, None, True) + else: + # use paged attention + torch_npu._npu_paged_attention_splitfuse( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + mask=attn_metadata.attn_mask, + block_table=attn_metadata.block_tables, + seq_len=attn_metadata.query_lens, + context_lens=attn_metadata.seq_lens, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output) return output.view(num_tokens, self.hidden_size) + + +def unified_ascend_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward(self, + query, + key, + value, + kv_cache, + attn_metadata, + output, + trace_flag=False) + return + + +def unified_attention_with_output_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_ascend_attention_with_output", + op_func=unified_ascend_attention_with_output, + mutates_args=["output"], + fake_impl=unified_attention_with_output_fake, + dispatch_key="PrivateUse1", +) diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index 1947799..71d86a2 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -14,8 +14,37 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # + +import torch +import torch_npu # noqa: F401 + import vllm_ascend.ops.activation # noqa import vllm_ascend.ops.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa import vllm_ascend.ops.rotary_embedding # noqa import vllm_ascend.ops.vocab_parallel_embedding # noqa + + +class dummyFusionOp: + default = None + + def __init__(self, name=""): + self.name = name + + +def register_dummy_fusion_op() -> None: + torch.cuda.CUDAGraph = torch_npu.npu.NPUGraph + torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm") + torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm") + torch.ops._C.static_scaled_fp8_quant = dummyFusionOp( + name="static_scaled_fp8_quant") + torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp( + name="dynamic_scaled_fp8_quant") + torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp( + name="dynamic_per_token_scaled_fp8_quant") + torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp( + name="rms_norm_static_fp8_quant") + torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp( + name="fused_add_rms_norm_static_fp8_quant") + torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp( + name="rms_norm_dynamic_per_token_quant") diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index ccaee9d..c82d4e8 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -114,11 +114,33 @@ class NPUPlatform(Platform): def check_and_update_config(cls, vllm_config: VllmConfig) -> None: from vllm.config import CompilationLevel # noqa: E402 compilation_config = vllm_config.compilation_config - if compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION: + + enforce_eager_flag = False + # Check whether the eager mode is configured + try: + enforce_eager_flag = vllm_config.model_config.enforce_eager + except Exception: logger.warning( - "Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION", + "There is currently no enforce_eager mode configured, the default value of enforce_eager=False is used" + ) + + if enforce_eager_flag or compilation_config.level == CompilationLevel.NO_COMPILATION: + logger.warning( + "Compilation level PIECEWISE is not enable on NPU now, current compilation level to NO_COMPILATION" + ) + compilation_config.level = CompilationLevel.NO_COMPILATION + elif compilation_config.level != CompilationLevel.PIECEWISE: + logger.warning( + "Compilation level %s is not enable on NPU now, forcing compilation level to NO_COMPILATION", compilation_config.level) compilation_config.level = CompilationLevel.NO_COMPILATION + else: + logger.info( + "Compilation level PIECEWISE is enable on NPU now, But use_inductor is no support, only use npu_graph now" + ) + compilation_config.use_inductor = False + compilation_config.splitting_ops.extend( + ["vllm.unified_ascend_attention_with_output"]) if vllm_config.additional_config is not None: enable_graph_mode = vllm_config.additional_config.get( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8be8e23..ca157ab 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -19,7 +19,10 @@ import gc import os +import time import weakref +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Union import numpy as np @@ -28,7 +31,7 @@ import torch import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY @@ -58,6 +61,43 @@ else: xgr = LazyLoader("xgr", globals(), "xgrammar") +@dataclass +class GraphCaptureContext: + stream: torch.npu.Stream + + +@contextmanager +def graph_capture(device: torch.device): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the NPU graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current NPU stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + graph_capture_context = GraphCaptureContext( + torch.npu.Stream(device=device)) + stream = graph_capture_context.stream + + # we use nullcontext now + maybe_ca_context = nullcontext() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.npu.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.npu.stream(stream), maybe_ca_context: + yield graph_capture_context + + class NPUModelRunner: def __init__(self, vllm_config: VllmConfig, device: torch.device): @@ -229,6 +269,12 @@ class NPUModelRunner: device="cpu") self.attn_mask = None self.attn_state = None + self.use_npu_graph = (self.vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE + and not self.model_config.enforce_eager) + self.npugraph_batch_sizes = list( + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) # NOTE: Pre-construct a mask matrix to improve the efficiency of # attention mask construction during inference. @@ -724,19 +770,19 @@ class NPUModelRunner: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) @torch.inference_mode() - def _dummy_run(self) -> torch.Tensor: + def _dummy_run(self, num_tokens: int) -> torch.Tensor: model = self.model if self.is_multimodal_model: input_ids = None - inputs_embeds = self.inputs_embeds[:self.max_num_tokens] + inputs_embeds = self.inputs_embeds[:num_tokens] else: - input_ids = self.input_ids[:self.max_num_tokens] + input_ids = self.input_ids[:num_tokens] inputs_embeds = None if self.uses_mrope: - positions = self.mrope_positions[:, :self.max_num_tokens] + positions = self.mrope_positions[:, :num_tokens] else: - positions = self.input_positions_cpu[:self.max_num_tokens] + positions = self.positions[:num_tokens] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -744,17 +790,17 @@ class NPUModelRunner: if self.intermediate_tensors is None: self.intermediate_tensors = ( self.model.make_empty_intermediate_tensors( - batch_size=self.max_num_tokens, + batch_size=num_tokens, dtype=self.dtype, device=self.device)) intermediate_tensors = IntermediateTensors({ - k: v[:self.max_num_tokens] + k: v[:num_tokens] for k, v in self.intermediate_tensors.items() }) with set_forward_context(None, self.vllm_config): hidden_states = model(input_ids=input_ids, - positions=positions.to(self.device), + positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states @@ -787,7 +833,7 @@ class NPUModelRunner: ] # Trigger compilation for general shape. - hidden_states = self._dummy_run() + hidden_states = self._dummy_run(self.max_num_tokens) if get_pp_group().is_last_rank: hidden_states = hidden_states[logit_indices] @@ -892,3 +938,31 @@ class NPUModelRunner: f"Unknown attention type: {attn_module.attn_type}") return kv_cache_spec + + def capture_model(self) -> None: + if not self.use_npu_graph: + logger.warning( + "Skipping NPU graph capture. Please add " + "-O %s to use NPU graphs.", CompilationLevel.PIECEWISE) + return + + start_time = time.perf_counter() + start_free_npu_memory = torch.npu.mem_get_info()[0] + + # Trigger NPU graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + with graph_capture(device=self.device): + for num_tokens in reversed(self.npugraph_batch_sizes): + for _ in range(self.vllm_config.compilation_config. + cudagraph_num_of_warmups): + self._dummy_run(num_tokens) + self._dummy_run(num_tokens) + + end_time = time.perf_counter() + end_free_npu_memory = torch.npu.mem_get_info()[0] + elapsed_time = end_time - start_time + npu_graph_size = start_free_npu_memory - end_free_npu_memory + # This usually takes 5~20 seconds. + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, npu_graph_size / (1 << 30)) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 7e98d4b..3dacd13 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -23,6 +23,7 @@ from typing import Dict, List, Optional import torch import torch.nn as nn import torch_npu +from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions from vllm import envs from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, @@ -65,7 +66,9 @@ class NPUWorker(WorkerBase): from vllm_ascend.utils import adapt_patch adapt_patch() # Register ops when worker init. - from vllm_ascend import ops # noqa: F401 + from vllm_ascend import ops + ops.register_dummy_fusion_op() + _register_atb_extensions() super().__init__(vllm_config=vllm_config, local_rank=local_rank, @@ -179,8 +182,17 @@ class NPUWorker(WorkerBase): self.model_runner.load_model() def compile_or_warm_up_model(self) -> None: + warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() if not self.model_config.enforce_eager: - logger.warning("Graph capture is not supported on NPU.") + warmup_sizes = [ + x for x in warmup_sizes if x not in + self.vllm_config.compilation_config.cudagraph_capture_sizes + ] + for size in sorted(warmup_sizes, reverse=True): + logger.info("Compile and warming up model for size %d", size) + self.model_runner._dummy_run(size) + if not self.model_config.enforce_eager: + self.model_runner.capture_model() # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed)