[V1][Core] Add support for V1 Engine (#295)

### What this PR does / why we need it?
Add support for V1 Engine.

Please note that this is just the initial version; some places may still need to be fixed or optimized in the future, so feel free to leave comments for us.

### Does this PR introduce _any_ user-facing change?

To use the V1 Engine on an NPU device, you need to set the environment variables shown below:

```bash
export VLLM_USE_V1=1
export VLLM_WORKER_MULTIPROC_METHOD=spawn
```

If you are using vLLM for offline inference, you must add a `__main__`
guard like:

```python
if __name__ == '__main__':
    llm = vllm.LLM(...)
```

Find more details
[here](https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing).
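Putting it all together, a minimal offline-inference sketch might look like the following. This is only an illustration that mirrors the serving example below (same model, `max_model_len`, prompt, and sampling values); setting the environment variables from Python before importing vLLM is an alternative to the shell exports shown above.

```python
import os

# Enable the V1 Engine before vLLM is imported (equivalent to the shell
# exports shown above).
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM, SamplingParams

if __name__ == '__main__':
    # The __main__ guard is required because the V1 workers are started
    # with the multiprocessing "spawn" method.
    llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", max_model_len=26240)
    sampling_params = SamplingParams(temperature=0, max_tokens=7)
    outputs = llm.generate(["The future of AI is"], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)
```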

### How was this patch tested?
I have tested online serving with `Qwen2.5-7B-Instruct` using this
command:

```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --max_model_len 26240
```

Query the model with input prompts:

```bash
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "Qwen/Qwen2.5-7B-Instruct",
        "prompt": "The future of AI is",
        "max_tokens": 7,
        "temperature": 0
    }'
```

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: didongli182 <didongli@huawei.com>
Author: Shanshan Shen
Date: 2025-03-20 19:34:44 +08:00
Committed by: GitHub
Parent: 663dca7578
Commit: c06af8b2e0
13 changed files with 1385 additions and 38 deletions


@@ -44,7 +44,7 @@ Find more about how to setup your environment step by step in [here](docs/source
## Getting Started
> [!NOTE]
> Currently, we are actively collaborating with the vLLM community to support the Ascend backend plugin, once supported you can use one line command `pip install vllm vllm-ascend` to compelete installation.
> Currently, we are actively collaborating with the vLLM community to support the Ascend backend plugin, once supported you can use one line command `pip install vllm vllm-ascend` to complete installation.
Installation from source code:
```bash


@@ -64,7 +64,7 @@ docker run --rm \
-it $IMAGE bash
```
:::{dropdown} Click here to see "Install CANN manally"
:::{dropdown} Click here to see "Install CANN manually"
:animate: fade-in-slide-down
You can also install CANN manually:


@@ -56,7 +56,7 @@ ray start --address='{head_node_ip}:{port_num}' --num-gpus=8 --node-ip-address={
```
:::{note}
If you're running DeepSeek V3/R1, please remove `quantization_config` section in `config.json` file since it's not supported by vllm-ascend currentlly.
If you're running DeepSeek V3/R1, please remove `quantization_config` section in `config.json` file since it's not supported by vllm-ascend currently.
:::
Start the vLLM server on head node:


@@ -25,8 +25,8 @@
- Pin modelscope<1.23.0 on vLLM v0.7.3 to resolve: https://github.com/vllm-project/vllm/pull/13807
### Known issues
- In [some cases](https://github.com/vllm-project/vllm-ascend/issues/324), expecially when the input/output is very long, the accuracy of output may be incorrect. We are working on it. It'll be fixed in the next release.
- Improved and reduced the garbled code in model output. But if you still hit the issue, try to change the gerneration config value, such as `temperature`, and try again. There is also a knonwn issue shown below. Any [feedback](https://github.com/vllm-project/vllm-ascend/issues/267) is welcome. [#277](https://github.com/vllm-project/vllm-ascend/pull/277)
- In [some cases](https://github.com/vllm-project/vllm-ascend/issues/324), especially when the input/output is very long, the accuracy of output may be incorrect. We are working on it. It'll be fixed in the next release.
- Improved and reduced the garbled code in model output. But if you still hit the issue, try to change the generation config value, such as `temperature`, and try again. There is also a knonwn issue shown below. Any [feedback](https://github.com/vllm-project/vllm-ascend/issues/267) is welcome. [#277](https://github.com/vllm-project/vllm-ascend/pull/277)
## v0.7.1rc1
@@ -46,7 +46,7 @@ Please follow the [official doc](https://vllm-ascend.readthedocs.io/en/v0.7.1-de
### Core
- Added the Ascend quantization config option, the implementation will comming soon. [#7](https://github.com/vllm-project/vllm-ascend/pull/7) [#73](https://github.com/vllm-project/vllm-ascend/pull/73)
- Added the Ascend quantization config option, the implementation will coming soon. [#7](https://github.com/vllm-project/vllm-ascend/pull/7) [#73](https://github.com/vllm-project/vllm-ascend/pull/73)
- Add silu_and_mul and rope ops and add mix ops into attention layer. [#18](https://github.com/vllm-project/vllm-ascend/pull/18)
### Other
@@ -58,5 +58,5 @@ Please follow the [official doc](https://vllm-ascend.readthedocs.io/en/v0.7.1-de
### Known issues
- This release relies on an unreleased torch_npu version. It has been installed within official container image already. Please [install](https://vllm-ascend.readthedocs.io/en/v0.7.1rc1/installation.html) it manually if you are using non-container environment.
- There are logs like `No platform deteced, vLLM is running on UnspecifiedPlatform` or `Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")` shown when runing vllm-ascend. It actually doesn't affect any functionality and performance. You can just ignore it. And it has been fixed in this [PR](https://github.com/vllm-project/vllm/pull/12432) which will be included in v0.7.3 soon.
- There are logs like `# CPU blocks: 35064, # CPU blocks: 2730` shown when runing vllm-ascend which should be `# NPU blocks:` . It actually doesn't affect any functionality and performance. You can just ignore it. And it has been fixed in this [PR](https://github.com/vllm-project/vllm/pull/13378) which will be included in v0.7.3 soon.
- There are logs like `No platform detected, vLLM is running on UnspecifiedPlatform` or `Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")` shown when running vllm-ascend. It actually doesn't affect any functionality and performance. You can just ignore it. And it has been fixed in this [PR](https://github.com/vllm-project/vllm/pull/12432) which will be included in v0.7.3 soon.
- There are logs like `# CPU blocks: 35064, # CPU blocks: 2730` shown when running vllm-ascend which should be `# NPU blocks:` . It actually doesn't affect any functionality and performance. You can just ignore it. And it has been fixed in this [PR](https://github.com/vllm-project/vllm/pull/13378) which will be included in v0.7.3 soon.


@@ -9,7 +9,7 @@
| Speculative decoding | ✅ | | | Basic functions available | Need fully test |
| Pooling | ✅ | | | Basic functions available(Bert) | Need fully test and add more models support|
| Enc-dec | ❌ | | | NA | Plan in 2025.06.30|
| Multi Modality | ✅ | | ✅ | Basic functions available(LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Improve perforamance, and add more models support |
| Multi Modality | ✅ | | ✅ | Basic functions available(LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Improve performance, and add more models support |
| LogProbs | ✅ | | | Basic functions available | Need fully test |
| Prompt logProbs | ✅ | | | Basic functions available | Need fully test |
| Async output | ✅ | | | Basic functions available | Need fully test |


@@ -144,7 +144,7 @@ CODESPELL_EXCLUDES=(
)
CODESPELL_IGNORE_WORDS=(
'-L' 'CANN,NNAL,ASCEND'
'-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend'
)
# check spelling of specified files
@@ -152,7 +152,7 @@ spell_check() {
codespell "$@" "${CODESPELL_IGNORE_WORDS[@]}"
}
spell_check_all(){
spell_check_all() {
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
}
@@ -168,6 +168,7 @@ spell_check_changed() {
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
codespell "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
codespell "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
fi
}

test.py (new file, 31 lines)

@@ -0,0 +1,31 @@
import os
import torch
import torch_npu # noqa: F401
device_id = 0
def _device_id_to_physical_device_id(device_id: int) -> int:
if "ASCEND_RT_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")
if device_ids == [""]:
raise RuntimeError("ASCEND_RT_VISIBLE_DEVICES is set to empty"
"string, which means Ascend NPU support is"
"disabled.")
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
return device_id
physical_device_id = _device_id_to_physical_device_id(device_id)
print("physical_device_id: " + str(physical_device_id))
# return torch.npu.get_device_name(physical_device_id)
torch.npu.get_device_name(device_id)
for k, v in os.environ.items():
if k == "ASCEND_RT_VISIBLE_DEVICES":
print(k)
print(v)


@@ -20,7 +20,7 @@
#
if command -v actionlint &> /dev/null; then
# NOTE: avoid check .github/workflows/vllm_ascend_test.yaml becase sel-hosted runner `npu-arm64` is unknown
# NOTE: avoid check .github/workflows/vllm_ascend_test.yaml because sel-hosted runner `npu-arm64` is unknown
actionlint .github/workflows/*.yml .github/workflows/mypy.yaml
exit 0
elif [ -x ./actionlint ]; then


@@ -0,0 +1,225 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
class AscendAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "ASCEND"
@staticmethod
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
return AscendAttentionBackendImpl
@staticmethod
def get_metadata_cls() -> Type["AscendMetadata"]:
return AscendMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)
@staticmethod
def swap_blocks(
src_kv_cache: List[torch.Tensor],
dst_kv_cache: List[torch.Tensor],
src_to_dst: torch.Tensor,
) -> None:
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
src_indices = src_to_dst[:, 0]
dst_indices = src_to_dst[:, 1]
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
dst_key_cache.device)
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
dst_key_cache.device)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
src_indices = src_to_dists[:, 0]
dst_indices = src_to_dists[:, 1]
for kv_cache in kv_caches:
key_caches = kv_cache[0]
value_caches = kv_cache[1]
key_caches[dst_indices] = key_caches[src_indices]
value_caches[dst_indices] = value_caches[src_indices]
@dataclass
class AscendMetadata:
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
block_tables: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]] = None
context_lens: Optional[List[int]] = None
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor = None
# TODO: Indicates whether there are only prefill requests.
# FlashAttention can be used when there are only prefill requests.
# FlashAttention has better performance than PageAtttention,
# but it does not support decode requests.
is_only_prefill: bool = False
attn_mask: Optional[torch.Tensor] = None
class AscendAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.hidden_size = self.num_heads * self.head_size
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes,
dtype=torch.float32,
device="npu")
self.alibi_slopes = alibi_slopes
self.attn_type = attn_type
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.seq_len_cpu_tensor = None
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size,
num_kv_heads * head_size]
key_cache = [num_blocks, block_size,
num_kv_heads * head_size]
value_cache = [num_blocks, block_size,
num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
if attn_metadata is None:
# Profiling run.
return output.view(num_tokens, self.hidden_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()
if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
else:
if kv_cache.numel() > 0:
key_cache, value_cache = kv_cache[0], kv_cache[1]
num_blocks, block_size, _ = key_cache.shape
key_cache = key_cache.view(num_blocks, block_size,
self.num_kv_heads, self.head_size)
value_cache = value_cache.view(num_blocks, block_size,
self.num_kv_heads,
self.head_size)
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_indices=slots)
# use paged attention
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=key_cache,
value_cache=value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.seq_lens,
context_lens=attn_metadata.context_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
return output.view(num_tokens, self.hidden_size)


@@ -19,13 +19,10 @@ import os
from typing import TYPE_CHECKING, Optional, Tuple
import torch
try:
import torch_npu # noqa: F401
except ImportError:
print("Failed to import torch_npu.")
from vllm.config import VllmConfig
import torch_npu # noqa: F401
import vllm.envs as envs
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import Platform, PlatformEnum
if TYPE_CHECKING:
@@ -35,18 +32,7 @@ else:
os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
def _device_id_to_physical_device_id(device_id: int) -> int:
if "ASCEND_RT_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")
if device_ids == [""]:
raise RuntimeError("ASCEND_RT_VISIBLE_DEVICES is set to empty"
"string, which means Ascend NPU support is"
"disabled.")
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
return device_id
logger = init_logger(__name__)
class NPUPlatform(Platform):
@@ -74,8 +60,7 @@ class NPUPlatform(Platform):
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = _device_id_to_physical_device_id(device_id)
return torch.npu.get_device_name(physical_device_id)
return torch.npu.get_device_name(device_id)
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
@@ -103,23 +88,41 @@ class NPUPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
compilation_config = vllm_config.compilation_config
if compilation_config.level != CompilationLevel.NO_COMPILATION:
logger.warning(
"Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
compilation_config.level)
compilation_config.level = CompilationLevel.NO_COMPILATION
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
if vllm_config.scheduler_config.is_multi_step:
parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
if envs.VLLM_USE_V1:
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
else:
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
if vllm_config.scheduler_config.is_multi_step:
parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
else:
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
cache_config.block_size = 128
if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching:
logger.warning(
"Prefix caching is not supported for V1 now, disable prefix caching"
)
cache_config.enable_prefix_caching = False
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla):
if use_v1:
return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
if use_mla:
return "vllm_ascend.attention.AscendMLAAttentionBackend"
return "vllm_ascend.attention.AscendAttentionBackend"
return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
return "vllm_ascend.attention.attention.AscendAttentionBackend"
@classmethod
def get_current_memory_usage(cls,
@@ -131,3 +134,7 @@ class NPUPlatform(Platform):
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm_ascend.communicator.NPUCommunicator"
@classmethod
def is_pin_memory_available(cls):
return True


@@ -0,0 +1,837 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from vllm.attention import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendMetadata)
if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput
logger = init_logger(__name__)
class NPUModelRunner:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self.is_multimodal_model = model_config.is_multimodal_model
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
self.num_query_heads = model_config.get_num_attention_heads(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# Lazy initialization
# self.model: nn.Module # Set after load_model
self.kv_caches: List[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
# Set up speculative decoding.
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
self.rejection_sampler = RejectionSampler()
# TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank:
self.drafter = NgramProposer()
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.drafter.propose(
np.zeros(1024, dtype=np.int32),
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=model_config.get_vocab_size(),
)
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: Optional[IntermediateTensors] = None
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with torch compile.
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
dtype=torch.int64,
device=self.device)
self.mrope_positions_cpu = torch.zeros(
(3, self.max_num_tokens + 1),
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)
# OPTIMIZATION: Cache the tensors rather than creating them every step.
self.arange_np = np.arange(max(self.max_num_reqs + 1,
self.max_model_len,
self.max_num_tokens),
dtype=np.int32)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
self.input_ids_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.input_ids_np = self.input_ids_cpu.numpy()
self.positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy()
self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.input_positions_cpu = torch.arange(0,
self.max_num_tokens,
device="cpu")
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
The SamplingMetadata is updated and copied to the NPU if there is a
new/resumed/paused/finished request in the batch.
"""
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
# then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request
# and handling the second as a new request.
removed_req_indices: List[int] = []
for req_id in scheduler_output.finished_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
# or running requests that are not scheduled in this step. We remove
# them from the persistent batch but keep their cached states since
# they will be scheduled again sometime in the future.
scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
cached_req_ids = self.input_batch.req_id_to_index.keys()
unscheduled_req_ids = cached_req_ids - scheduled_req_ids
# NOTE(woosuk): The persistent batch optimization assumes that
# consecutive batches contain mostly the same requests. If batches
# have low request overlap (e.g., alternating between two distinct
# sets of requests), this optimization becomes very inefficient.
for req_id in unscheduled_req_ids:
req_index = self.input_batch.remove_request(req_id)
assert req_index is not None
removed_req_indices.append(req_index)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
)
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.
for req_data in scheduler_output.scheduled_cached_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
# Update the cached states.
num_computed_tokens = req_data.num_computed_tokens
req_state.num_computed_tokens = num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens = (num_computed_tokens +
len(req_data.new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(req_data.new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
req_data.new_token_ids[-num_new_tokens:])
# Update the block IDs.
if not req_data.resumed_from_preemption:
# Append the new blocks to the existing block IDs.
req_state.block_ids.extend(req_data.new_block_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = req_data.new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
req_ids_to_add.append(req_id)
continue
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
start_index = (len(req_state.block_ids) -
len(req_data.new_block_ids))
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(req_data.new_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = req_data.new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
if spec_token_ids:
start_index = end_token_index
end_token_index += len(spec_token_ids)
self.input_batch.token_ids_cpu[
req_index, start_index:end_token_index] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
self.input_batch.num_tokens[req_index] = end_token_index
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
if batch_changed:
self.input_batch.refresh_sampling_metadata()
def get_model(self) -> nn.Module:
return self.model
@staticmethod
def make_attention_mask(kv_dtype, kv_device, max_seq_len, seq_lens,
query_lens):
# for paged attention
atten_mask = np.zeros([0, max_seq_len])
for i, context_length in enumerate(seq_lens):
q_len = query_lens[i]
ones_len = context_length - q_len
ones = np.ones((q_len, ones_len), dtype=np.float16)
bias_cache = np.tril(
np.ones((q_len, max_seq_len - ones_len), dtype=np.float16))
bias_cache = np.concatenate((ones, bias_cache), axis=1)
mask_value = -10000
bias_cache[bias_cache == 0] = mask_value
bias_cache[bias_cache == 1] = 0
atten_mask = np.concatenate([atten_mask, bias_cache], axis=0)
atten_mask = torch.from_numpy(atten_mask).to(kv_dtype).to(kv_device)
return atten_mask
def _process_reqs(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
# check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit(num_reqs)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
max_num_scheduled_tokens = 0
for i, req_id in enumerate(self.input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens[i] = num_tokens
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
# prepare positions
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
cu_num_tokens = np.cumsum(num_scheduled_tokens)
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
num_scheduled_tokens)
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
positions = self.positions[:total_num_scheduled_tokens]
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
seq_lens = self.seq_lens_cpu[:num_reqs]
query_lens = torch.from_numpy(num_scheduled_tokens)
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens])
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
self.device, non_blocking=True)
attn_mask = self.make_attention_mask(
self.vllm_config.model_config.dtype, self.device,
max(seq_lens, default=0), seq_lens, num_scheduled_tokens)
attn_metadata = AscendMetadata(
seq_lens=query_lens,
context_lens=seq_lens,
slot_mapping=slot_mapping,
block_tables=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
attn_mask=attn_mask,
)
# prepare input_ids
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Copy the tensors to the NPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
input_ids = self.input_ids[:total_num_scheduled_tokens]
# Run forward pass
with set_forward_context(attn_metadata, self.vllm_config):
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
)
return hidden_states[cu_num_tokens - 1]
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]:
self._update_states(scheduler_output)
hidden_states = self._process_reqs(scheduler_output,
intermediate_tensors)
logits = self.model.compute_logits(hidden_states, None)
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# Ignore the sampled token.
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# NOTE: NPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
valid_mask = sampled_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
# TODO(woosuk): Optimize this.
valid_sampled_token_ids = [
seq.tolist()
for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=None,
logprobs=logprobs_lists,
prompt_logprobs_dict={},
)
return model_runner_output
def _profile_multimodal(self) -> None:
# TODO: handle encoder-decoder models once we support them.
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
if (not self.is_multimodal_model
or self.max_num_encoder_input_tokens <= 0
or self.encoder_cache_size <= 0):
return
max_tokens_by_modality_dict = (
MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
self.model_config))
dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1])
# Check how many items of this modality can be supported by
# the encoder budget.
encoder_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)
max_num_mm_items_encoder_budget = cdiv(encoder_budget,
max_tokens_per_mm_item)
# Check how many items of this modality can be supported by
# the decoder budget.
max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
self.model_config)[dummy_data_modality]
# NOTE: We do not consider max_num_batched_tokens on purpose
# because the multimodal embeddings can be generated in advance
# and chunked prefilled.
max_num_mm_items_decoder_budget = self.max_num_reqs * \
max_mm_items_per_req
max_num_mm_items = min(max_num_mm_items_encoder_budget,
max_num_mm_items_decoder_budget)
logger.info(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size.",
encoder_budget, max_num_mm_items, dummy_data_modality)
# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
if not isinstance(dummy_mm_data, MultiModalKwargs):
# TODO: Delete this check once input mapper is fully removed.
raise RuntimeError("Legacy input mapper is not supported in V1")
# Dummy data definition in V0 may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
# they are scheduled to be processed separately.
dummy_mm_item = dummy_mm_data.get_item(modality=dummy_data_modality,
item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device)
# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
assert len(dummy_encoder_outputs) == max_num_mm_items, (
"Expected dimension 0 of encoder outputs to match the number "
f"of multimodal data items: {max_num_mm_items}, got "
f"{len(dummy_encoder_outputs)=} instead. This is most likely "
"due to the 'get_multimodal_embeddings' method of the model "
"not implemented correctly.")
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
@torch.inference_mode()
def _dummy_run(
self,
num_tokens: int,
) -> torch.Tensor:
model = self.model
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.input_positions_cpu[:num_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
with set_forward_context(None, self.vllm_config):
hidden_states = model(input_ids=input_ids,
positions=positions.to(self.device),
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
self._profile_multimodal()
# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.
num_reqs = self.scheduler_config.max_num_seqs
num_tokens = self.max_num_tokens
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
# assert self.lora_manager is not None, "LoRA is not enabled"
# TODO: call maybe_profile_with_lora()
dummy_kv_caches = [
torch.tensor((), dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
else:
logits = None
current_platform.synchronize()
del hidden_states, logits, dummy_kv_caches
self.encoder_cache.clear()
gc.collect()
def generate_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
draft_token_ids.append([])
continue
# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx],
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else:
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
raise ValueError("LoRA model is not supported on NPU now.")
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: Dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = AscendAttentionBackend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size)
dtype = layer_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
else:
raise NotImplementedError
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: KVCacheSpec = {}
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
continue
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec


@@ -0,0 +1,246 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
from typing import Dict, List, Optional
import torch
import torch.distributed
import torch.nn as nn
import torch_npu
from vllm import envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.worker_base import WorkerBase
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
logger = init_logger(__name__)
class NPUWorker(WorkerBase):
def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False):
# Register ops when worker init.
from vllm_ascend import ops # noqa: F401
super().__init__(vllm_config=vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
experimental_config = torch_npu.profiler._ExperimentalConfig(
export_type=torch_npu.profiler.ExportType.Text,
profiler_level=torch_npu.profiler.ProfilerLevel.Level0,
msprof_tx=False,
aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
l2_cache=False,
op_attr=False,
data_simplification=False,
record_op_args=False,
gc_detect_threshold=None,
)
self.profiler = torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU,
],
with_stack=True,
profile_memory=True,
with_modules=True,
experimental_config=experimental_config,
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir))
else:
self.profiler = None
def init_device(self):
if self.device_config.device.type == "npu":
self.device = torch.device(f"npu:{self.local_rank}")
current_platform.set_device(self.device)
current_platform.empty_cache()
self.init_npu_memory = current_platform.mem_get_info()[0]
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
# Init ModelRunner here, so that we have access to self.device.
self.model_runner = NPUModelRunner(self.vllm_config, self.device)
def determine_available_memory(self) -> int:
kv_caches: Dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec):
dtype = layer_spec.dtype
# Use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
runner_kv_caches: List[torch.Tensor] = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
current_platform.empty_cache()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
free_npu_memory, total_npu_memory = current_platform.mem_get_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_npu_memory - free_npu_memory
assert peak_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_npu_memory}, current free memory"
f" {free_npu_memory}. This happens when the NPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
gc.collect()
# TODO: don`t need impl this func after empty_cache in
# Worker.determine_num_available_blocks() unified`
current_platform.empty_cache()
usable_memory_size = total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory
npu_kv_cache_bytes = max(usable_memory_size, 0)
logger.info(
f"Available memory: {usable_memory_size}, total memory: {total_npu_memory}"
)
return int(npu_kv_cache_bytes)
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None
def load_model(self) -> None:
self.model_runner.load_model()
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
logger.warning("Graph capture is not supported on NPU.")
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate NPU KV cache with the specified kv_cache_config."""
self.model_runner.initialize_kv_cache(kv_cache_config)
def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config = kv_cache_configs[self.rank]
self.model_runner.initialize_kv_cache(kv_cache_config)
def profile(self, is_start: bool = True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1) -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank, "hccl")
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)