support aclgraph (#426)
### What this PR does / why we need it?
This PR enables vllm-ascend to use the piecewise-graph (aclgraph) feature provided by the vLLM v1 engine:
1. Registers `unified_ascend_attention_with_output` as a custom op so the piecewise compiler can split the graph at attention.
2. Supports NPUGraph capture to accelerate kernel launch.

### Does this PR introduce _any_ user-facing change?
NPU graph mode is now enabled by default. Users can disable it by setting `enforce_eager=True`. This places corresponding requirements on the torch_npu and CANN versions: they must support graph capture.

### How was this patch tested?
Graph mode is now the default path, so the existing CI tests exercise it; a new `tests/compile` suite is also added.

---------

Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
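For illustration, a minimal usage sketch of the user-facing switch (the model name and prompt are placeholders, not something this PR prescribes; assumes a torch_npu/CANN stack that supports graph capture):

```python
import os

os.environ["VLLM_USE_V1"] = "1"  # piecewise graph is a v1-engine feature

from vllm import LLM, SamplingParams

# Default after this PR: NPU graph (aclgraph) capture is enabled.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")

# Opting out: fall back to eager kernel launches.
llm_eager = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enforce_eager=True)

print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))
```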
.github/workflows/vllm_ascend_test.yaml
@@ -115,24 +115,26 @@ jobs:
- name: Install vllm-project/vllm-ascend
run: |
pip install -r requirements-dev.txt
pip install -e .

- name: Run vllm-project/vllm-ascend test on V0 engine
env:
VLLM_USE_V1: 0
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
pytest -sv tests/singlecard/test_offline_inference.py
pytest -sv tests/ops
else
pytest -sv tests/multicard/test_offline_inference_distributed.py
pytest -sv tests/ops
fi
pip install -v --no-build-isolation -e .

- name: Run vllm-project/vllm-ascend test for V1 Engine
env:
VLLM_USE_V1: 1
VLLM_WORKER_MULTIPROC_METHOD: spawn
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
pytest -sv tests/singlecard/test_offline_inference.py
pytest -sv tests/ops
pytest -sv tests/compile
else
pytest -sv tests/multicard/test_offline_inference_distributed.py
pytest -sv tests/ops
pytest -sv tests/compile
fi

- name: Run vllm-project/vllm-ascend test on V0 engine
env:
VLLM_USE_V1: 0
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
pytest -sv tests/singlecard/test_offline_inference.py
csrc/ops.h
@@ -21,6 +21,7 @@
#include <vector>
#include "kernels/types.h"
#include "torch_npu/csrc/aten/common/from_blob.h"

namespace vllm_ascend {
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
@@ -29,4 +30,20 @@ namespace vllm_ascend {
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
uint32_t aivNum);
}

torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
if (!tensor.is_privateuseone()) {
throw std::runtime_error("Tensor must be on NPU device");
}
// Get the raw data pointer
void* data_ptr = tensor.data_ptr();
// Get tensor sizes and strides
std::vector<int64_t> sizes = tensor.sizes().vec();
std::vector<int64_t> strides = tensor.strides().vec();
// Get tensor options (dtype, device)
auto options = tensor.options();
// Create a new tensor from the raw data pointer
auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
return new_tensor;
}
}

@@ -103,6 +103,8 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
TORCH_LIBRARY_EXPAND(_C, ops)
{
// vLLM-Ascend custom ops
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);

// Rotary embedding
// Apply GPT-NeoX style rotary embedding to query and key.
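For orientation, a hedged Python-side sketch of how the new `weak_ref_tensor` binding could be exercised (assumes a built vllm-ascend wheel and an available NPU; the extension import path is an assumption, not taken from this diff):

```python
import torch
import torch_npu  # noqa: F401  # registers the NPU ("PrivateUse1") backend
import vllm_ascend._C  # noqa: F401  # assumed import that loads the TORCH_LIBRARY ops

src = torch.ones(4, 4, device="npu")
view = torch.ops._C.weak_ref_tensor(src)

# The result is rebuilt with from_blob around src's storage, so it aliases
# the data without owning it -- useful when graph capture needs to reference
# tensors without extending their lifetime.
view[0, 0] = 42.0
assert src[0, 0].item() == 42.0
```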
@@ -11,8 +11,8 @@ requires = [
"scipy",
"setuptools>=64",
"setuptools-scm>=8",
"torch_npu",
"torch >= 2.5.1",
"torch_npu==2.5.1rc1",
"torch>=2.5.1",
"torchvision<0.21.0",
]
build-backend = "setuptools.build_meta"

@@ -1,4 +1,5 @@
-r requirements-lint.txt
-r requirements.txt
modelscope
pytest >= 6.0
pytest-asyncio

@@ -3,11 +3,12 @@ cmake>=3.26
decorator
numpy<2.0.0
packaging
pip
pybind11
pyyaml
scipy
setuptools>=64
setuptools-scm>=8
torch_npu
torch >= 2.5.1
torch>=2.5.1
torchvision<0.21.0
wheel
tests/compile/__init__.py (new, empty file)
tests/compile/test_simple.py (new file)
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""

import pytest
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         set_current_vllm_config)
from vllm.utils import direct_register_custom_op

global_counter = 0

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT")  # noqa


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    global global_counter
    global_counter += 1
    print(f"{global_counter=}")
    out.copy_(q)
    out[0] += 1


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    dispatch_key="PrivateUse1",
    target_lib=silly_lib,
)


@support_torch_compile
class SillyModel(nn.Module):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overall effect:
        x += 1
        x[0] += 2
        global_counter += 2
        """
        x = x + 1
        x = x + 2
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x


@pytest.mark.skipif(True, reason="requires unreleased components")
def test_simple_piecewise_compile():
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor=False,
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_copy_inputs=True,
        cudagraph_capture_sizes=[1, 2],
    ))
    vllm_config.compilation_config.pass_config.enable_fusion = False
    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix="")

    inputs = torch.randn(100).npu()

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
            num_cudagraph_caputured=
            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):

        model(inputs)

        model(torch.randn(2).npu())
        model(torch.randn(1).npu())

        input = torch.zeros(2).npu()
        global global_counter
        global_counter = 0
        output = model(input)
        assert global_counter == 2
        assert torch.allclose(output.cpu(), torch.tensor([3.0, 1.0]))


if __name__ == "__main__":
    test_simple_piecewise_compile()
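A brief note on the expected counters in the test above (reasoning only, not code from the diff): with `splitting_ops=["silly.attention"]` and two attention calls in `forward`, the compiler cuts the graph at each attention op.

```python
# Hypothetical walkthrough of the counter arithmetic used in the test.
num_attention_ops = 2                                   # two torch.ops.silly.attention calls
num_piecewise_graphs = 2 * num_attention_ops + 1        # pieces around/between the cuts -> 5
num_capturable = num_attention_ops + 1                  # the attention-free pieces -> 3
num_npu_graphs = num_capturable * len([1, 2])           # one capture per batch size -> 6
assert (num_piecewise_graphs, num_capturable, num_npu_graphs) == (5, 3, 6)
```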
@@ -47,6 +47,7 @@ def test_models_distributed(model: str,
dtype=dtype,
tensor_parallel_size=4,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

@@ -50,7 +50,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
with VllmRunner(model,
max_model_len=8192,
dtype=dtype,
enforce_eager=False,
enforce_eager=True,
gpu_memory_utilization=0.7) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
@@ -24,6 +24,8 @@ import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch

@@ -31,6 +33,7 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill

class AscendAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True

@staticmethod
def get_name() -> str:
@@ -198,6 +201,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
trace_flag: bool = True,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
@@ -215,98 +219,150 @@ class AscendAttentionBackendImpl(AttentionImpl):
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)

if attn_metadata is None:
# Profiling run.
return output.view(num_tokens, self.hidden_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()

if kv_cache.numel() > 0:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)

if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
# Normal V1 situation.
if output is None:
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
if trace_flag:
torch.ops.vllm.unified_ascend_attention_with_output(
query=query,
key=key,
value=value,
output=output,
layer_name=layer.layer_name)
else:
# use chunked prefill for head size 192 scenario, like deepseek
# paged_attention_splitfuse maybe crash at such scenario
# TODO: vanilla path will be removed after the kernel support
# head_size 192 scenario
if self.head_size == 192:
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
max_seqlen_q = torch.max(attn_metadata.query_lens)
max_seqlen_k = torch.max(attn_metadata.seq_lens)
vanilla_chunked_prefill(output, query, self.key_cache,
self.value_cache,
attn_metadata.block_tables,
cu_seqlen_q, cu_seqlen_k, max_seqlen_q,
max_seqlen_k, self.scale, None, True)
else:
torch_npu._npu_paged_attention_splitfuse(
num_tokens = query.shape[0]
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()

if kv_cache.numel() > 0:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)

if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
# Normal V1 situation.
else:
# use chunked prefill for head size 192 scenario, like deepseek
# paged_attention_splitfuse maybe crash at such scenario
# TODO: vanilla path will be removed after the kernel support
# head_size 192 scenario
if self.head_size == 192:
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
max_seqlen_q = torch.max(attn_metadata.query_lens)
max_seqlen_k = torch.max(attn_metadata.seq_lens)
vanilla_chunked_prefill(output, query, self.key_cache,
self.value_cache,
attn_metadata.block_tables,
cu_seqlen_q, cu_seqlen_k,
max_seqlen_q, max_seqlen_k,
self.scale, None, True)
else:
# use paged attention
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
return output.view(num_tokens, self.hidden_size)


def unified_ascend_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
key,
value,
kv_cache,
attn_metadata,
output,
trace_flag=False)
return


def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return


direct_register_custom_op(
op_name="unified_ascend_attention_with_output",
op_func=unified_ascend_attention_with_output,
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key="PrivateUse1",
)
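As a hedged sketch of how this op registration interacts with piecewise compilation (the configuration values below are illustrative; only the op name comes from this diff): once `unified_ascend_attention_with_output` is registered under the `vllm` namespace, the platform can ask the compiler to cut the graph at every attention call, and each attention-free piece can then be captured as an NPU graph.

```python
from vllm.config import CompilationConfig, CompilationLevel

# Illustrative configuration mirroring what NPUPlatform.check_and_update_config
# arranges automatically; the capture sizes are made up for the example.
compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_inductor=False,                    # Inductor is not used on NPU
    cudagraph_capture_sizes=[1, 2, 4, 8],  # reused as NPU-graph batch sizes
)
compilation_config.splitting_ops.append(
    "vllm.unified_ascend_attention_with_output")
```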
@@ -14,8 +14,37 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
import torch_npu  # noqa: F401

import vllm_ascend.ops.activation  # noqa
import vllm_ascend.ops.fused_moe  # noqa
import vllm_ascend.ops.layernorm  # noqa
import vllm_ascend.ops.rotary_embedding  # noqa
import vllm_ascend.ops.vocab_parallel_embedding  # noqa


class dummyFusionOp:
    default = None

    def __init__(self, name=""):
        self.name = name


def register_dummy_fusion_op() -> None:
    torch.cuda.CUDAGraph = torch_npu.npu.NPUGraph
    torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
    torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
    torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
        name="static_scaled_fp8_quant")
    torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp(
        name="dynamic_scaled_fp8_quant")
    torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
        name="dynamic_per_token_scaled_fp8_quant")
    torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp(
        name="rms_norm_static_fp8_quant")
    torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
        name="fused_add_rms_norm_static_fp8_quant")
    torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp(
        name="rms_norm_dynamic_per_token_quant")
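Why the placeholders matter, as a hedged sketch (the worker wiring appears later in this diff; the snippet only illustrates the intent): vLLM's compilation and fusion passes probe CUDA-only custom ops such as `torch.ops._C.rms_norm`, which do not exist on NPU, and the piecewise backend instantiates `torch.cuda.CUDAGraph` for capture; pointing that at `torch_npu.npu.NPUGraph` lets the same code path capture NPU graphs.

```python
# Illustrative only; mirrors how NPUWorker calls this at init time.
import torch
import torch_npu  # noqa: F401

from vllm_ascend import ops

ops.register_dummy_fusion_op()

# Attribute lookups performed by vLLM's fusion passes now resolve to inert
# placeholders instead of raising, and CUDAGraph is aliased to NPUGraph.
assert torch.ops._C.rms_norm.default is None
assert torch.cuda.CUDAGraph is torch_npu.npu.NPUGraph
```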
@@ -114,11 +114,33 @@ class NPUPlatform(Platform):
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel  # noqa: E402
compilation_config = vllm_config.compilation_config
if compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION:

enforce_eager_flag = False
# Check whether the eager mode is configured
try:
enforce_eager_flag = vllm_config.model_config.enforce_eager
except Exception:
logger.warning(
"Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
"There is currently no enforce_eager mode configured, the default value of enforce_eager=False is used"
)

if enforce_eager_flag or compilation_config.level == CompilationLevel.NO_COMPILATION:
logger.warning(
"Compilation level PIECEWISE is not enable on NPU now, current compilation level to NO_COMPILATION"
)
compilation_config.level = CompilationLevel.NO_COMPILATION
elif compilation_config.level != CompilationLevel.PIECEWISE:
logger.warning(
"Compilation level %s is not enable on NPU now, forcing compilation level to NO_COMPILATION",
compilation_config.level)
compilation_config.level = CompilationLevel.NO_COMPILATION
else:
logger.info(
"Compilation level PIECEWISE is enable on NPU now, But use_inductor is no support, only use npu_graph now"
)
compilation_config.use_inductor = False
compilation_config.splitting_ops.extend(
["vllm.unified_ascend_attention_with_output"])

if vllm_config.additional_config is not None:
enable_graph_mode = vllm_config.additional_config.get(
@@ -19,7 +19,10 @@

import gc
import os
import time
import weakref
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np
@@ -28,7 +31,7 @@ import torch
import torch.nn as nn
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
@@ -58,6 +61,43 @@ else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")


@dataclass
class GraphCaptureContext:
    stream: torch.npu.Stream


@contextmanager
def graph_capture(device: torch.device):
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the NPU graph. Its main purpose is to ensure that the
    some operations will be run after the graph is captured, before the graph
    is replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current NPU stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture
    from other kernels possibly launched on background in the default stream.
    """
    graph_capture_context = GraphCaptureContext(
        torch.npu.Stream(device=device))
    stream = graph_capture_context.stream

    # we use nullcontext now
    maybe_ca_context = nullcontext()

    # ensure all initialization operations complete before attempting to
    # capture the graph on another stream
    curr_stream = torch.npu.current_stream()
    if curr_stream != stream:
        stream.wait_stream(curr_stream)

    with torch.npu.stream(stream), maybe_ca_context:
        yield graph_capture_context
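A minimal, hedged usage sketch of the context manager above (assumes an available NPU device and that `graph_capture` is used from this module; the work inside is a stand-in, not real capture logic):

```python
import torch

device = torch.device("npu:0")
with graph_capture(device=device) as capture_ctx:
    # Kernels launched here run on the dedicated capture stream, kept apart
    # from whatever the default stream is doing in the background.
    x = torch.zeros(8, device=device)
    x.add_(1)
assert capture_ctx.stream is not None
```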
class NPUModelRunner:

def __init__(self, vllm_config: VllmConfig, device: torch.device):
@@ -229,6 +269,12 @@ class NPUModelRunner:
device="cpu")
self.attn_mask = None
self.attn_state = None
self.use_npu_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
self.npugraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))

# NOTE: Pre-construct a mask matrix to improve the efficiency of
# attention mask construction during inference.
@@ -724,19 +770,19 @@ class NPUModelRunner:
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

@torch.inference_mode()
def _dummy_run(self) -> torch.Tensor:
def _dummy_run(self, num_tokens: int) -> torch.Tensor:
model = self.model
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:self.max_num_tokens]
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:self.max_num_tokens]
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None

if self.uses_mrope:
positions = self.mrope_positions[:, :self.max_num_tokens]
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.input_positions_cpu[:self.max_num_tokens]
positions = self.positions[:num_tokens]

if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -744,17 +790,17 @@ class NPUModelRunner:
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
batch_size=num_tokens,
dtype=self.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:self.max_num_tokens]
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})

with set_forward_context(None, self.vllm_config):
hidden_states = model(input_ids=input_ids,
positions=positions.to(self.device),
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
@@ -787,7 +833,7 @@ class NPUModelRunner:
]

# Trigger compilation for general shape.
hidden_states = self._dummy_run()
hidden_states = self._dummy_run(self.max_num_tokens)

if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices]
@@ -892,3 +938,31 @@ class NPUModelRunner:
f"Unknown attention type: {attn_module.attn_type}")

return kv_cache_spec

def capture_model(self) -> None:
if not self.use_npu_graph:
logger.warning(
"Skipping NPU graph capture. Please add "
"-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
return

start_time = time.perf_counter()
start_free_npu_memory = torch.npu.mem_get_info()[0]

# Trigger NPU graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
for num_tokens in reversed(self.npugraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)

end_time = time.perf_counter()
end_free_npu_memory = torch.npu.mem_get_info()[0]
elapsed_time = end_time - start_time
npu_graph_size = start_free_npu_memory - end_free_npu_memory
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))
@@ -23,6 +23,7 @@ from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch_npu
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
@@ -65,7 +66,9 @@ class NPUWorker(WorkerBase):
from vllm_ascend.utils import adapt_patch
adapt_patch()
# Register ops when worker init.
from vllm_ascend import ops  # noqa: F401
from vllm_ascend import ops
ops.register_dummy_fusion_op()
_register_atb_extensions()

super().__init__(vllm_config=vllm_config,
local_rank=local_rank,
@@ -179,8 +182,17 @@ class NPUWorker(WorkerBase):
self.model_runner.load_model()

def compile_or_warm_up_model(self) -> None:
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
logger.warning("Graph capture is not supported on NPU.")
warmup_sizes = [
x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes
]
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
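A hedged illustration of the warm-up selection in `compile_or_warm_up_model` above (the values are invented): compile sizes that also appear in the NPU-graph capture list are filtered out here because `capture_model()` will run them during graph capture anyway.

```python
# Hypothetical values, not from the diff.
compile_sizes = [1, 2, 4, 8]
cudagraph_capture_sizes = [1, 2]

warmup_sizes = [x for x in compile_sizes if x not in cudagraph_capture_sizes]
for size in sorted(warmup_sizes, reverse=True):
    print(f"warm up size {size}")  # -> 8, then 4
```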