support aclgraph (#426)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->
This PR supports the access of vllm-acend to the piecewise_graph feature
provided by the v1 engine.

1. register unifiled_ascend_attention_with_output for piecewise_graph to
split graph.
2. support NPUGraph to accelerate kernel launch.

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
support npugraph to default, Users can disenable the npugraph feature by
configuring enforce_eager.

This has corresponding requirements for the versions of torch_npu and
CANN, and they need to support graph capture.

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
it turn to default

---------

Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Bug Hunter Yan
2025-04-23 20:56:24 +08:00
committed by GitHub
parent 5c6d05a59e
commit 05bdcbeae4
15 changed files with 454 additions and 119 deletions

View File

@@ -24,6 +24,8 @@ import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -31,6 +33,7 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill
class AscendAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
@@ -198,6 +201,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
trace_flag: bool = True,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
@@ -215,98 +219,150 @@ class AscendAttentionBackendImpl(AttentionImpl):
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
if attn_metadata is None:
# Profiling run.
return output.view(num_tokens, self.hidden_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()
if kv_cache.numel() > 0:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
# Normal V1 situation.
if output is None:
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
if trace_flag:
torch.ops.vllm.unified_ascend_attention_with_output(
query=query,
key=key,
value=value,
output=output,
layer_name=layer.layer_name)
else:
# use chunked prefill for head size 192 scenario, like deepseek
# paged_attention_splitfuse maybe crash at such scenario
# TODO: vanilla path will be removed after the kernel support
# head_size 192 scenario
if self.head_size == 192:
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
max_seqlen_q = torch.max(attn_metadata.query_lens)
max_seqlen_k = torch.max(attn_metadata.seq_lens)
vanilla_chunked_prefill(output, query, self.key_cache,
self.value_cache,
attn_metadata.block_tables,
cu_seqlen_q, cu_seqlen_k, max_seqlen_q,
max_seqlen_k, self.scale, None, True)
else:
torch_npu._npu_paged_attention_splitfuse(
num_tokens = query.shape[0]
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()
if kv_cache.numel() > 0:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
# Normal V1 situation.
else:
# use chunked prefill for head size 192 scenario, like deepseek
# paged_attention_splitfuse maybe crash at such scenario
# TODO: vanilla path will be removed after the kernel support
# head_size 192 scenario
if self.head_size == 192:
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
max_seqlen_q = torch.max(attn_metadata.query_lens)
max_seqlen_k = torch.max(attn_metadata.seq_lens)
vanilla_chunked_prefill(output, query, self.key_cache,
self.value_cache,
attn_metadata.block_tables,
cu_seqlen_q, cu_seqlen_k,
max_seqlen_q, max_seqlen_k,
self.scale, None, True)
else:
# use paged attention
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
return output.view(num_tokens, self.hidden_size)
def unified_ascend_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
key,
value,
kv_cache,
attn_metadata,
output,
trace_flag=False)
return
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="unified_ascend_attention_with_output",
op_func=unified_ascend_attention_with_output,
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key="PrivateUse1",
)

View File

@@ -14,8 +14,37 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
import torch_npu # noqa: F401
import vllm_ascend.ops.activation # noqa
import vllm_ascend.ops.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.rotary_embedding # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa
class dummyFusionOp:
default = None
def __init__(self, name=""):
self.name = name
def register_dummy_fusion_op() -> None:
torch.cuda.CUDAGraph = torch_npu.npu.NPUGraph
torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
name="static_scaled_fp8_quant")
torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp(
name="dynamic_scaled_fp8_quant")
torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
name="dynamic_per_token_scaled_fp8_quant")
torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp(
name="rms_norm_static_fp8_quant")
torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
name="fused_add_rms_norm_static_fp8_quant")
torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp(
name="rms_norm_dynamic_per_token_quant")

View File

@@ -114,11 +114,33 @@ class NPUPlatform(Platform):
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel # noqa: E402
compilation_config = vllm_config.compilation_config
if compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION:
enforce_eager_flag = False
# Check whether the eager mode is configured
try:
enforce_eager_flag = vllm_config.model_config.enforce_eager
except Exception:
logger.warning(
"Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
"There is currently no enforce_eager mode configured, the default value of enforce_eager=False is used"
)
if enforce_eager_flag or compilation_config.level == CompilationLevel.NO_COMPILATION:
logger.warning(
"Compilation level PIECEWISE is not enable on NPU now, current compilation level to NO_COMPILATION"
)
compilation_config.level = CompilationLevel.NO_COMPILATION
elif compilation_config.level != CompilationLevel.PIECEWISE:
logger.warning(
"Compilation level %s is not enable on NPU now, forcing compilation level to NO_COMPILATION",
compilation_config.level)
compilation_config.level = CompilationLevel.NO_COMPILATION
else:
logger.info(
"Compilation level PIECEWISE is enable on NPU now, But use_inductor is no support, only use npu_graph now"
)
compilation_config.use_inductor = False
compilation_config.splitting_ops.extend(
["vllm.unified_ascend_attention_with_output"])
if vllm_config.additional_config is not None:
enable_graph_mode = vllm_config.additional_config.get(

View File

@@ -19,7 +19,10 @@
import gc
import os
import time
import weakref
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
@@ -28,7 +31,7 @@ import torch
import torch.nn as nn
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
@@ -58,6 +61,43 @@ else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
@dataclass
class GraphCaptureContext:
stream: torch.npu.Stream
@contextmanager
def graph_capture(device: torch.device):
"""
`graph_capture` is a context manager which should surround the code that
is capturing the NPU graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current NPU stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
graph_capture_context = GraphCaptureContext(
torch.npu.Stream(device=device))
stream = graph_capture_context.stream
# we use nullcontext now
maybe_ca_context = nullcontext()
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.npu.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.npu.stream(stream), maybe_ca_context:
yield graph_capture_context
class NPUModelRunner:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
@@ -229,6 +269,12 @@ class NPUModelRunner:
device="cpu")
self.attn_mask = None
self.attn_state = None
self.use_npu_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
self.npugraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
# NOTE: Pre-construct a mask matrix to improve the efficiency of
# attention mask construction during inference.
@@ -724,19 +770,19 @@ class NPUModelRunner:
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
@torch.inference_mode()
def _dummy_run(self) -> torch.Tensor:
def _dummy_run(self, num_tokens: int) -> torch.Tensor:
model = self.model
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:self.max_num_tokens]
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:self.max_num_tokens]
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :self.max_num_tokens]
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.input_positions_cpu[:self.max_num_tokens]
positions = self.positions[:num_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -744,17 +790,17 @@ class NPUModelRunner:
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
batch_size=num_tokens,
dtype=self.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:self.max_num_tokens]
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
with set_forward_context(None, self.vllm_config):
hidden_states = model(input_ids=input_ids,
positions=positions.to(self.device),
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
@@ -787,7 +833,7 @@ class NPUModelRunner:
]
# Trigger compilation for general shape.
hidden_states = self._dummy_run()
hidden_states = self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices]
@@ -892,3 +938,31 @@ class NPUModelRunner:
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec
def capture_model(self) -> None:
if not self.use_npu_graph:
logger.warning(
"Skipping NPU graph capture. Please add "
"-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
return
start_time = time.perf_counter()
start_free_npu_memory = torch.npu.mem_get_info()[0]
# Trigger NPU graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
for num_tokens in reversed(self.npugraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
end_time = time.perf_counter()
end_free_npu_memory = torch.npu.mem_get_info()[0]
elapsed_time = end_time - start_time
npu_graph_size = start_free_npu_memory - end_free_npu_memory
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))

View File

@@ -23,6 +23,7 @@ from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch_npu
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
@@ -65,7 +66,9 @@ class NPUWorker(WorkerBase):
from vllm_ascend.utils import adapt_patch
adapt_patch()
# Register ops when worker init.
from vllm_ascend import ops # noqa: F401
from vllm_ascend import ops
ops.register_dummy_fusion_op()
_register_atb_extensions()
super().__init__(vllm_config=vllm_config,
local_rank=local_rank,
@@ -179,8 +182,17 @@ class NPUWorker(WorkerBase):
self.model_runner.load_model()
def compile_or_warm_up_model(self) -> None:
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
logger.warning("Graph capture is not supported on NPU.")
warmup_sizes = [
x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes
]
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)