support aclgraph (#426)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine.

1. Register unified_ascend_attention_with_output as a splitting op so that
piecewise_graph can split the graph.
2. support NPUGraph to accelerate kernel launch.

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is now enabled by default. Users can disable the NPUGraph feature by
setting enforce_eager.

This has corresponding requirements for the versions of torch_npu and
CANN, and they need to support graph capture.

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
It is turned on by default.

---------

Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Bug Hunter Yan
2025-04-23 20:56:24 +08:00
committed by GitHub
parent 5c6d05a59e
commit 05bdcbeae4
15 changed files with 454 additions and 119 deletions

View File

@@ -115,24 +115,26 @@ jobs:
- name: Install vllm-project/vllm-ascend
run: |
pip install -r requirements-dev.txt
pip install -e .
- name: Run vllm-project/vllm-ascend test on V0 engine
env:
VLLM_USE_V1: 0
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
pytest -sv tests/singlecard/test_offline_inference.py
pytest -sv tests/ops
else
pytest -sv tests/multicard/test_offline_inference_distributed.py
pytest -sv tests/ops
fi
pip install -v --no-build-isolation -e .
- name: Run vllm-project/vllm-ascend test for V1 Engine
env:
VLLM_USE_V1: 1
VLLM_WORKER_MULTIPROC_METHOD: spawn
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
pytest -sv tests/singlecard/test_offline_inference.py
pytest -sv tests/ops
pytest -sv tests/compile
else
pytest -sv tests/multicard/test_offline_inference_distributed.py
pytest -sv tests/ops
pytest -sv tests/compile
fi
- name: Run vllm-project/vllm-ascend test on V0 engine
env:
VLLM_USE_V1: 0
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
pytest -sv tests/singlecard/test_offline_inference.py

View File

@@ -21,6 +21,7 @@
#include <vector>
#include "kernels/types.h"
#include "torch_npu/csrc/aten/common/from_blob.h"
namespace vllm_ascend {
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
@@ -29,4 +30,20 @@ namespace vllm_ascend {
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
uint32_t aivNum);
}
/// Create a second tensor that aliases the data of an NPU tensor.
///
/// The result is built with at_npu::native::from_blob over the input's data
/// pointer and carries the same sizes, strides and tensor options.
///
/// @param tensor Tensor to alias; must report is_privateuseone().
/// @return A new tensor wrapping the same data pointer.
/// @throws std::runtime_error if the tensor is not on the PrivateUse1 device.
torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
    const bool on_npu = tensor.is_privateuseone();
    if (!on_npu) {
        throw std::runtime_error("Tensor must be on NPU device");
    }

    // Copy the geometry into owning vectors before handing it to from_blob.
    const std::vector<int64_t> shape = tensor.sizes().vec();
    const std::vector<int64_t> stride = tensor.strides().vec();

    // NOTE(review): no deleter is passed, so the alias presumably does not
    // extend the lifetime of the original storage — confirm against torch_npu.
    return at_npu::native::from_blob(tensor.data_ptr(), shape, stride, tensor.options());
}
}

View File

@@ -103,6 +103,8 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
TORCH_LIBRARY_EXPAND(_C, ops)
{
// vLLM-Ascend custom ops
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
// Rotary embedding
// Apply GPT-NeoX style rotary embedding to query and key.

View File

@@ -11,8 +11,8 @@ requires = [
"scipy",
"setuptools>=64",
"setuptools-scm>=8",
"torch_npu",
"torch >= 2.5.1",
"torch_npu==2.5.1rc1",
"torch>=2.5.1",
"torchvision<0.21.0",
]
build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,5 @@
-r requirements-lint.txt
-r requirements.txt
modelscope
pytest >= 6.0
pytest-asyncio

View File

@@ -3,11 +3,12 @@ cmake>=3.26
decorator
numpy<2.0.0
packaging
pip
pybind11
pyyaml
scipy
setuptools>=64
setuptools-scm>=8
torch_npu
torch >= 2.5.1
torch>=2.5.1
torchvision<0.21.0
wheel

View File

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import pytest
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.utils import direct_register_custom_op
global_counter = 0
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    """Toy attention op used by the piecewise-compile test.

    Bumps the module-level ``global_counter`` by one and prints it, then
    writes ``q`` into ``out`` and adds 1 to ``out[0]``. ``k`` and ``v`` are
    accepted but not read.
    """
    global global_counter
    global_counter = global_counter + 1
    print(f"{global_counter=}")

    # Result is written into the caller-provided buffer; nothing is returned.
    out.copy_(q)
    out[0] += 1
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    """Fake implementation of the silly attention op: does nothing."""
    return None
# Register silly_attention as the "attention" op in silly_lib under the
# PrivateUse1 dispatch key. "out" is declared as the argument it mutates and
# silly_attention_fake is supplied as the fake implementation.
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
dispatch_key="PrivateUse1",
target_lib=silly_lib,
)
# Parameter-free model used by the piecewise-compilation test. Its forward
# interleaves elementwise arithmetic with two calls to the silly.attention
# custom op, which the test lists in splitting_ops.
@support_torch_compile
class SillyModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs) -> None:
# vllm_config, prefix and kwargs are accepted but not used in this body.
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overall effect:
x += 1
x[0] += 2
global_counter += 2
"""
# Piece before the first attention call: x + 3 in total.
x = x + 1
x = x + 2
# First attention call: out = x, then out[0] += 1.
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
x = out
# Piece between the two attention calls: undoes the + 3 above.
x = x - 2
x = x - 1
# Second attention call: out = x, then out[0] += 1 again.
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
x = out
# Final piece: net result is x + 1 everywhere, plus 2 more on x[0].
x = x + 1
return x
# Checks the compilation counters and the numeric result of SillyModel under
# PIECEWISE compilation split at silly.attention.
# The skipif condition is the constant True, so this test is always skipped.
@pytest.mark.skipif(True, reason="requires unreleased components")
def test_simple_piecewise_compile():
# Piecewise compilation without inductor, capturing graphs for sizes 1 and 2.
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_inductor=False,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
))
vllm_config.compilation_config.pass_config.enable_fusion = False
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix="")
inputs = torch.randn(100).npu()
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_caputured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
# One call at size 100, then one call for each configured capture size.
model(inputs)
model(torch.randn(2).npu())
model(torch.randn(1).npu())
input = torch.zeros(2).npu()
# Reset the counter so only the attention calls of the next forward are counted.
global global_counter
global_counter = 0
output = model(input)
# Two attention calls per forward.
assert global_counter == 2
# zeros -> +1 everywhere and +2 more on element 0 (see SillyModel.forward).
assert torch.allclose(output.cpu(), torch.tensor([3.0, 1.0]))
if __name__ == "__main__":
test_simple_piecewise_compile()

View File

@@ -47,6 +47,7 @@ def test_models_distributed(model: str,
dtype=dtype,
tensor_parallel_size=4,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -50,7 +50,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
with VllmRunner(model,
max_model_len=8192,
dtype=dtype,
enforce_eager=False,
enforce_eager=True,
gpu_memory_utilization=0.7) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -24,6 +24,8 @@ import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -31,6 +33,7 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill
class AscendAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
@@ -198,6 +201,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
trace_flag: bool = True,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
@@ -215,98 +219,150 @@ class AscendAttentionBackendImpl(AttentionImpl):
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
if attn_metadata is None:
# Profiling run.
return output.view(num_tokens, self.hidden_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()
if kv_cache.numel() > 0:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
# Normal V1 situation.
if output is None:
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
if trace_flag:
torch.ops.vllm.unified_ascend_attention_with_output(
query=query,
key=key,
value=value,
output=output,
layer_name=layer.layer_name)
else:
# use chunked prefill for head size 192 scenario, like deepseek
# paged_attention_splitfuse maybe crash at such scenario
# TODO: vanilla path will be removed after the kernel support
# head_size 192 scenario
if self.head_size == 192:
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
max_seqlen_q = torch.max(attn_metadata.query_lens)
max_seqlen_k = torch.max(attn_metadata.seq_lens)
vanilla_chunked_prefill(output, query, self.key_cache,
self.value_cache,
attn_metadata.block_tables,
cu_seqlen_q, cu_seqlen_k, max_seqlen_q,
max_seqlen_k, self.scale, None, True)
else:
torch_npu._npu_paged_attention_splitfuse(
num_tokens = query.shape[0]
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()
if kv_cache.numel() > 0:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
# Normal V1 situation.
else:
# use chunked prefill for head size 192 scenario, like deepseek
# paged_attention_splitfuse maybe crash at such scenario
# TODO: vanilla path will be removed after the kernel support
# head_size 192 scenario
if self.head_size == 192:
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
max_seqlen_q = torch.max(attn_metadata.query_lens)
max_seqlen_k = torch.max(attn_metadata.seq_lens)
vanilla_chunked_prefill(output, query, self.key_cache,
self.value_cache,
attn_metadata.block_tables,
cu_seqlen_q, cu_seqlen_k,
max_seqlen_q, max_seqlen_k,
self.scale, None, True)
else:
# use paged attention
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
return output.view(num_tokens, self.hidden_size)
def unified_ascend_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Custom-op entry point that runs a layer's attention impl.

    Looks up the layer registered under ``layer_name`` in the current forward
    context, selects its KV cache for the context's virtual engine, and calls
    the layer's ``impl.forward`` with ``output`` as the output argument.

    Args:
        query: Query tensor forwarded to the attention impl.
        key: Key tensor forwarded to the attention impl.
        value: Value tensor forwarded to the attention impl.
        output: Tensor forwarded as the impl's output argument.
        layer_name: Key into ``forward_context.no_compile_layers``.
    """
    ctx: ForwardContext = get_forward_context()
    attn_layer = ctx.no_compile_layers[layer_name]
    layer_kv_cache = attn_layer.kv_cache[ctx.virtual_engine]

    # trace_flag=False makes the impl take its direct-execution branch instead
    # of dispatching back to this custom op.
    attn_layer.impl.forward(
        attn_layer,
        query,
        key,
        value,
        layer_kv_cache,
        ctx.attn_metadata,
        output,
        trace_flag=False,
    )
    return None
def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of the unified attention op: does nothing."""
    return None
# Register the attention entry point as a custom op under the PrivateUse1
# dispatch key. "output" is declared as the argument it mutates and
# unified_attention_with_output_fake is supplied as the fake implementation.
direct_register_custom_op(
op_name="unified_ascend_attention_with_output",
op_func=unified_ascend_attention_with_output,
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key="PrivateUse1",
)

View File

@@ -14,8 +14,37 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
import torch_npu # noqa: F401
import vllm_ascend.ops.activation # noqa
import vllm_ascend.ops.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.rotary_embedding # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa
class dummyFusionOp:
    """Placeholder object that only carries an op name.

    Instances expose ``name``; the class attribute ``default`` is always None.
    """

    default = None

    def __init__(self, name=""):
        # Store the op name exactly as given; the default is an empty string.
        self.name = name
def register_dummy_fusion_op() -> None:
    """Install NPU stand-ins on the torch namespaces.

    Points ``torch.cuda.CUDAGraph`` at ``torch_npu.npu.NPUGraph`` and assigns
    a ``dummyFusionOp`` placeholder to each listed attribute of
    ``torch.ops._C``.
    """
    torch.cuda.CUDAGraph = torch_npu.npu.NPUGraph

    # Assigned in this order, one placeholder per name.
    placeholder_names = (
        "rms_norm",
        "fused_add_rms_norm",
        "static_scaled_fp8_quant",
        "dynamic_scaled_fp8_quant",
        "dynamic_per_token_scaled_fp8_quant",
        "rms_norm_static_fp8_quant",
        "fused_add_rms_norm_static_fp8_quant",
        "rms_norm_dynamic_per_token_quant",
    )
    for op_name in placeholder_names:
        setattr(torch.ops._C, op_name, dummyFusionOp(name=op_name))

View File

@@ -114,11 +114,33 @@ class NPUPlatform(Platform):
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel # noqa: E402
compilation_config = vllm_config.compilation_config
if compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION:
enforce_eager_flag = False
# Check whether the eager mode is configured
try:
enforce_eager_flag = vllm_config.model_config.enforce_eager
except Exception:
logger.warning(
"Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
"There is currently no enforce_eager mode configured, the default value of enforce_eager=False is used"
)
if enforce_eager_flag or compilation_config.level == CompilationLevel.NO_COMPILATION:
logger.warning(
"Compilation level PIECEWISE is not enable on NPU now, current compilation level to NO_COMPILATION"
)
compilation_config.level = CompilationLevel.NO_COMPILATION
elif compilation_config.level != CompilationLevel.PIECEWISE:
logger.warning(
"Compilation level %s is not enable on NPU now, forcing compilation level to NO_COMPILATION",
compilation_config.level)
compilation_config.level = CompilationLevel.NO_COMPILATION
else:
logger.info(
"Compilation level PIECEWISE is enable on NPU now, But use_inductor is no support, only use npu_graph now"
)
compilation_config.use_inductor = False
compilation_config.splitting_ops.extend(
["vllm.unified_ascend_attention_with_output"])
if vllm_config.additional_config is not None:
enable_graph_mode = vllm_config.additional_config.get(

View File

@@ -19,7 +19,10 @@
import gc
import os
import time
import weakref
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
@@ -28,7 +31,7 @@ import torch
import torch.nn as nn
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
@@ -58,6 +61,43 @@ else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
# Value yielded by the `graph_capture` context manager to its body.
@dataclass
class GraphCaptureContext:
# Stream that graph_capture creates and makes current while its body runs.
stream: torch.npu.Stream
@contextmanager
def graph_capture(device: torch.device):
    """Context manager to wrap code that captures an NPU graph.

    A new ``torch.npu.Stream`` is created on ``device`` and made the current
    stream for the duration of the ``with`` body, so that the kernels being
    captured run on a stream separate from the one that was current on entry.
    Before switching, the new stream is made to wait on the stream that was
    current, so work already queued there finishes first.

    Yields:
        GraphCaptureContext: currently only holds the capture stream.
    """
    capture_stream = torch.npu.Stream(device=device)
    context = GraphCaptureContext(capture_stream)

    # we use nullcontext now
    maybe_ca_context = nullcontext()

    # Make sure earlier initialization work on the current stream completes
    # before anything is captured on the new stream.
    active_stream = torch.npu.current_stream()
    if active_stream != capture_stream:
        capture_stream.wait_stream(active_stream)

    with torch.npu.stream(capture_stream), maybe_ca_context:
        yield context
class NPUModelRunner:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
@@ -229,6 +269,12 @@ class NPUModelRunner:
device="cpu")
self.attn_mask = None
self.attn_state = None
self.use_npu_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
self.npugraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
# NOTE: Pre-construct a mask matrix to improve the efficiency of
# attention mask construction during inference.
@@ -724,19 +770,19 @@ class NPUModelRunner:
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
@torch.inference_mode()
def _dummy_run(self) -> torch.Tensor:
def _dummy_run(self, num_tokens: int) -> torch.Tensor:
model = self.model
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:self.max_num_tokens]
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:self.max_num_tokens]
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :self.max_num_tokens]
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.input_positions_cpu[:self.max_num_tokens]
positions = self.positions[:num_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -744,17 +790,17 @@ class NPUModelRunner:
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
batch_size=num_tokens,
dtype=self.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:self.max_num_tokens]
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
with set_forward_context(None, self.vllm_config):
hidden_states = model(input_ids=input_ids,
positions=positions.to(self.device),
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
@@ -787,7 +833,7 @@ class NPUModelRunner:
]
# Trigger compilation for general shape.
hidden_states = self._dummy_run()
hidden_states = self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices]
@@ -892,3 +938,31 @@ class NPUModelRunner:
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec
def capture_model(self) -> None:
    """Run the dummy forward passes that capture NPU graphs.

    Does nothing except log a warning when ``self.use_npu_graph`` is false.
    Otherwise, inside ``graph_capture``, each size in
    ``self.npugraph_batch_sizes`` gets ``cudagraph_num_of_warmups`` warm-up
    dummy runs followed by one more dummy run. The elapsed time and the drop
    in free NPU memory are logged afterwards.
    """
    if not self.use_npu_graph:
        logger.warning(
            "Skipping NPU graph capture. Please add "
            "-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
        return

    t_begin = time.perf_counter()
    free_before = torch.npu.mem_get_info()[0]

    # Trigger NPU graph capture for specific shapes.
    # NOTE(review): sizes are visited in reverse list order, presumably
    # largest first so smaller shapes can reuse the memory pool — confirm.
    with graph_capture(device=self.device):
        for num_tokens in reversed(self.npugraph_batch_sizes):
            warmup_rounds = (
                self.vllm_config.compilation_config.cudagraph_num_of_warmups)
            for _ in range(warmup_rounds):
                self._dummy_run(num_tokens)
            # Extra run after the warm-ups.
            self._dummy_run(num_tokens)

    elapsed_time = time.perf_counter() - t_begin
    free_after = torch.npu.mem_get_info()[0]
    npu_graph_size = free_before - free_after
    # This usually takes 5~20 seconds.
    logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                elapsed_time, npu_graph_size / (1 << 30))

View File

@@ -23,6 +23,7 @@ from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch_npu
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
@@ -65,7 +66,9 @@ class NPUWorker(WorkerBase):
from vllm_ascend.utils import adapt_patch
adapt_patch()
# Register ops when worker init.
from vllm_ascend import ops # noqa: F401
from vllm_ascend import ops
ops.register_dummy_fusion_op()
_register_atb_extensions()
super().__init__(vllm_config=vllm_config,
local_rank=local_rank,
@@ -179,8 +182,17 @@ class NPUWorker(WorkerBase):
self.model_runner.load_model()
def compile_or_warm_up_model(self) -> None:
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
logger.warning("Graph capture is not supported on NPU.")
warmup_sizes = [
x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes
]
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)