[Core] Init vllm-ascend (#3)
### What this PR does / why we need it?
vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on
the Ascend NPU.
This plugin is the recommended approach for supporting the Ascend
backend within the vLLM community. It adheres to the principles outlined
in the [RFC]: Hardware pluggable, providing a hardware-pluggable
interface that decouples the integration of the Ascend NPU with vLLM.
This patch also include changes to make CI work and use cache speed up
e2e test, including:
1. Change push (post merge ci) and pull_request (pr ci) trigger branch
to main
2. Make mypy work by ignore base_communicator and clear unused deps
3. Several improvements for vllm_ascend_test:
- use cache (pip, ms, hf) speed up e2e test (25mins --> 5mins)
- switch `git clone` command to `action/checkout` to speedup checkout
and
- Enable sv for pytest for better info dump
- Remove network host to resole `docker: conflicting ontions: cannot
attach both user-defined and non-user-definednetwork-modes`, which is a
problem on docker 1.45 but not on 1.39.
4. Adapt MLA decode optimizations:
cabaf4eff3
### Does this PR introduce _any_ user-facing change?
Yes, init the PR.
### How was this patch tested?
- This is the first PR to make ascend NPU work on vLLM. All code is
tested on ascend with vLLM V0 Engine.
- CI passed
---------
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: wangshuai09 <391746016@qq.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
21
vllm_ascend/__init__.py
Normal file
21
vllm_ascend/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
def register():
|
||||
"""Register the NPU platform."""
|
||||
return "vllm_ascend.platform.NPUPlatform"
|
||||
678
vllm_ascend/attention.py
Normal file
678
vllm_ascend/attention.py
Normal file
@@ -0,0 +1,678 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/attention/backends
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
except ImportError:
|
||||
print("Failed to import torch_npu.")
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
||||
CommonMetadataBuilder,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_ascend.model_runner import ModelInputForNPUBuilder
|
||||
|
||||
SHARE_MASK_TRIL_PREFIX_CACHE = None
|
||||
SHARE_MASK_TRIL = None
|
||||
|
||||
|
||||
class AscendAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ASCEND"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
|
||||
return AscendAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AscendMetadata"]:
|
||||
return AscendMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: List[torch.Tensor],
|
||||
dst_kv_cache: List[torch.Tensor],
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
|
||||
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
|
||||
src_indices = src_to_dst[:, 0]
|
||||
dst_indices = src_to_dst[:, 1]
|
||||
|
||||
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
|
||||
dst_key_cache.device)
|
||||
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
|
||||
dst_key_cache.device)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
src_indices = src_to_dists[:, 0]
|
||||
dst_indices = src_to_dists[:, 1]
|
||||
|
||||
for kv_cache in kv_caches:
|
||||
key_caches = kv_cache[0]
|
||||
value_caches = kv_cache[1]
|
||||
key_caches[dst_indices] = key_caches[src_indices]
|
||||
value_caches[dst_indices] = value_caches[src_indices]
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["AscendMetadataBuilder"]:
|
||||
return AscendMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder":
|
||||
return cls.get_builder_cls()(*args, **kwargs)
|
||||
|
||||
|
||||
class AscendPagedAttention(PagedAttention):
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_indices: torch.Tensor,
|
||||
) -> None:
|
||||
torch_npu.npu_scatter_nd_update_(key_cache, slot_indices, key)
|
||||
torch_npu.npu_scatter_nd_update_(value_cache, slot_indices, value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for Ascendbackend.
|
||||
* modified from XFormersbackend
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# FIXME: It is for flash attn.
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
|
||||
# FIXME: It is for flash attn.
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# Self-attention prefill/decode metadata cache
|
||||
_cached_prefill_metadata: Optional["AscendMetadata"] = None
|
||||
_cached_decode_metadata: Optional["AscendMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
pse_shift: Optional[torch.Tensor] = None
|
||||
sparse_mode: int = 0
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
# slot_mapping: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["AscendMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
# Recover cached prefill-phase attention
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = AscendMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["AscendMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = AscendMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
# Batch may be composed of prefill|decodes, adjust query start
|
||||
# indices to refer to the start of decodes. E.g.
|
||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||
self.query_start_loc[self.num_prefills])
|
||||
if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
|
||||
_metadata_cls = AscendMetadata
|
||||
|
||||
def compute_npu_slot_indices(self, is_profile_run, slot_indices, seq_id,
|
||||
seq_len, context_len, start_idx, block_size,
|
||||
block_tables, max_query_len):
|
||||
"""
|
||||
compute slot indices
|
||||
slot mapping in other backend of vllm stores slot indices,
|
||||
which are indicates by `block_number * block_size + block_offset`
|
||||
In Ascend backend, slot mapping stores [block_number, block_offset].
|
||||
To distinguish this, slot_indices is used in this func
|
||||
"""
|
||||
if is_profile_run:
|
||||
# During memory profiling, the block tables are not
|
||||
# initialized yet. In this case, we just use a dummy
|
||||
# slot mapping.
|
||||
# In embeddings, the block tables are {seq_id: None}.
|
||||
slot_indices.extend([[PAD_SLOT_ID, 0]] * seq_len)
|
||||
return
|
||||
# Mask the [0, start_idx) tokens of the prompt with
|
||||
# [PAD_SLOT_ID, 0], where start_idx is max(0, seq_len -
|
||||
# sliding_window). For example, if the prompt len is 10,
|
||||
# sliding window is 8, and block size is 4, the first two
|
||||
# tokens are masked and the slot mapping will be
|
||||
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
||||
padding_mask_len = max(0, start_idx - context_len)
|
||||
slot_indices.extend([[PAD_SLOT_ID, 0]] * padding_mask_len)
|
||||
|
||||
range_start = max(start_idx, context_len)
|
||||
range_end = seq_len
|
||||
numel = range_end - range_start
|
||||
block_table = block_tables[seq_id]
|
||||
|
||||
for i in range(range_start, range_end):
|
||||
block_number = block_table[i // block_size]
|
||||
block_offset = i % block_size
|
||||
slot_indices.append([block_number, block_offset])
|
||||
slot_indices.extend([[PAD_SLOT_ID, 0]] * (max_query_len - numel))
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
2. block table.
|
||||
3. slot mapping.
|
||||
"""
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
max_query_len = max(
|
||||
max(data.query_lens)
|
||||
for data in self.input_builder.inter_data_list)
|
||||
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
assert query_len == 1, (
|
||||
"seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len))
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
# TODO(sang): Combine chunked prefill and prefix caching by
|
||||
# only allowing multiple of block_size chunk size.
|
||||
# NOTE: This only works for oooooooxxx style attention.
|
||||
block_table: List[int] = []
|
||||
prefix_cache_hit = any([
|
||||
inter_data.prefix_cache_hit
|
||||
for inter_data in self.input_builder.inter_data_list
|
||||
])
|
||||
if prefix_cache_hit:
|
||||
# NOTE(woosuk): For flash-attn, the block table should
|
||||
# include the entries for the incoming prefill tokens.
|
||||
if block_tables is not None:
|
||||
block_table = block_tables[seq_id]
|
||||
elif ((chunked_prefill_enabled or not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
|
||||
self.compute_npu_slot_indices(is_profile_run, self.slot_mapping,
|
||||
seq_id, seq_len, context_len,
|
||||
start_idx, self.block_size,
|
||||
inter_data.block_tables,
|
||||
max_query_len)
|
||||
|
||||
|
||||
class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.sliding_window = sliding_window
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes,
|
||||
dtype=torch.float32,
|
||||
device="npu")
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.attn_type = attn_type
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: List[torch.Tensor],
|
||||
attn_metadata: AscendMetadata,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Ascend attention.
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
num_tokens = batch_size * seq_len
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache: shape = [2, num_blocks, block_size,
|
||||
num_kv_heads * head_size]
|
||||
key_cache = [num_blocks, block_size,
|
||||
num_kv_heads * head_size]
|
||||
value_cache = [num_blocks, block_size,
|
||||
num_kv_heads * head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [batch_size, seq_len * num_heads * head_size]
|
||||
"""
|
||||
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
|
||||
attn_type = self.attn_type
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
# view q k v to BSH
|
||||
num_tokens = query.shape[0]
|
||||
|
||||
if kv_cache is not None and len(kv_cache) >= 2:
|
||||
slot_indices = attn_metadata.slot_mapping
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
AscendPagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_indices,
|
||||
)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
if attn_metadata.attn_mask is None:
|
||||
if num_tokens > 16384:
|
||||
attn_metadata.sparse_mode = 2
|
||||
attention_mask = gen_input_mask(
|
||||
attn_metadata.max_prefill_seq_len, self.sliding_window,
|
||||
num_tokens)
|
||||
attn_metadata.attn_mask = attention_mask
|
||||
|
||||
if (self.alibi_slopes is not None
|
||||
and attn_metadata.pse_shift is None):
|
||||
attn_metadata.pse_shift = _make_alibi_bias(
|
||||
self.alibi_slopes,
|
||||
self.num_kv_heads,
|
||||
dtype=query.dtype,
|
||||
seq_len=attn_metadata.max_prefill_seq_len,
|
||||
batch_size=num_tokens,
|
||||
)
|
||||
|
||||
if (len(kv_cache) == 0 or attn_metadata.block_tables is None
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
|
||||
# shape of q/k/v [B,S*H] --> [B,S,N,D]
|
||||
query = query.view(-1, max_seq_len, self.num_heads,
|
||||
self.head_size).transpose(1, 2)
|
||||
key = key.view(-1, max_seq_len, self.num_kv_heads,
|
||||
self.head_size).transpose(1, 2)
|
||||
value = value.view(-1, max_seq_len, self.num_kv_heads,
|
||||
self.head_size).transpose(1, 2)
|
||||
# FA for prefill phase
|
||||
output = torch_npu.npu_prompt_flash_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
pse_shift=attn_metadata.pse_shift,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=1 / math.sqrt(self.head_size),
|
||||
input_layout="BNSD",
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
pre_tokens=65535,
|
||||
next_tokens=0,
|
||||
sparse_mode=attn_metadata.sparse_mode,
|
||||
)
|
||||
# reshape to [B,H]
|
||||
output = output.transpose(1, 2).reshape(
|
||||
num_tokens, self.num_heads * self.head_size)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support prefix caching")
|
||||
assert attn_metadata.seq_lens is not None
|
||||
assert kv_cache is not None
|
||||
query = query.view(query.shape[0], -1,
|
||||
self.num_heads * self.head_size)
|
||||
output = torch.zeros(query.shape,
|
||||
device="npu",
|
||||
dtype=query.dtype)
|
||||
# TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
|
||||
# support only when `S == 1`, OPTIMIZE ME when prefix caching
|
||||
# is supported in torch-npu ops.
|
||||
for i in range(query.shape[0]):
|
||||
# FA for prefill phase
|
||||
output[i] = torch_npu.npu_incre_flash_attention(
|
||||
query[i].unsqueeze(0),
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
scale_value=self.scale,
|
||||
input_layout="BSH",
|
||||
block_table=attn_metadata.block_tables,
|
||||
block_size=key_cache.
|
||||
shape[1], # max val of block_size == 512
|
||||
actual_seq_lengths=attn_metadata.seq_lens,
|
||||
)
|
||||
# [B,S,H] --> [B,H]
|
||||
output = output.squeeze(1)
|
||||
|
||||
elif attn_metadata.decode_metadata:
|
||||
# FA for decoding phase
|
||||
assert kv_cache is not None
|
||||
# shape of query [B,S*H] --> [B,S,H]
|
||||
query = query.view(
|
||||
-1,
|
||||
1,
|
||||
self.head_size * self.num_heads,
|
||||
)
|
||||
output = torch_npu.npu_incre_flash_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
scale_value=self.scale,
|
||||
input_layout="BSH",
|
||||
block_table=attn_metadata.block_tables,
|
||||
block_size=key_cache.shape[1], # max val of block_size == 512
|
||||
actual_seq_lengths=attn_metadata.seq_lens,
|
||||
)
|
||||
|
||||
# [B,S,H] --> [B,H]
|
||||
output = output.squeeze(1)
|
||||
return output
|
||||
|
||||
|
||||
def gen_input_mask(seq_len, sliding_window, len):
|
||||
"""
|
||||
Generating lower triangular matrix
|
||||
"""
|
||||
if len > 16384:
|
||||
# improve computing performance on NPU when input tokens are huge
|
||||
global SHARE_MASK_TRIL_PREFIX_CACHE
|
||||
if SHARE_MASK_TRIL_PREFIX_CACHE is None:
|
||||
SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(
|
||||
torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu"),
|
||||
diagonal=1,
|
||||
)
|
||||
attention_mask = SHARE_MASK_TRIL_PREFIX_CACHE
|
||||
else:
|
||||
global SHARE_MASK_TRIL
|
||||
if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
|
||||
SHARE_MASK_TRIL = ~torch.tril(
|
||||
torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
|
||||
|
||||
attention_mask = SHARE_MASK_TRIL
|
||||
if sliding_window is not None:
|
||||
attention_mask = ~attention_mask
|
||||
attention_mask = torch.triu(attention_mask,
|
||||
diagonal=1 - sliding_window)
|
||||
attention_mask = ~attention_mask
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
def _make_alibi_bias(
|
||||
alibi_slopes: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
dtype: torch.dtype,
|
||||
seq_len: int,
|
||||
batch_size: int,
|
||||
):
|
||||
bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
|
||||
# NOTE(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(seq_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
# Calculate a matrix where each element represents ith element- jth
|
||||
# element.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
|
||||
padded_len = (seq_len + 7) // 8 * 8
|
||||
num_heads = alibi_slopes.shape[0]
|
||||
bias = torch.empty(
|
||||
1,
|
||||
num_heads,
|
||||
seq_len,
|
||||
padded_len,
|
||||
device=alibi_slopes.device,
|
||||
dtype=dtype,
|
||||
)[:, :, :, :seq_len].copy_(bias)
|
||||
bias.mul_(alibi_slopes[:, None, None])
|
||||
if num_heads != num_kv_heads:
|
||||
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
|
||||
|
||||
return bias
|
||||
28
vllm_ascend/communicator.py
Normal file
28
vllm_ascend/communicator.py
Normal file
@@ -0,0 +1,28 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.device_communicators.base_communicator import \
|
||||
CommunicatorBase
|
||||
|
||||
|
||||
class NPUCommunicator(CommunicatorBase):
|
||||
|
||||
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(x, group=self.group)
|
||||
return x
|
||||
620
vllm_ascend/model_runner.py
Normal file
620
vllm_ascend/model_runner.py
Normal file
@@ -0,0 +1,620 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/model_runner.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Optional, Set, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderMap
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.layers import PromptAdapterMapping
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.utils import flatten_2d_lists, make_tensor_with_pad
|
||||
from vllm.worker.model_runner import (ModelInputForGPU,
|
||||
ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata,
|
||||
ModelRunner)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
LORA_WARMUP_RANK = 8
|
||||
|
||||
|
||||
class ModelInputForNPUBuilder(ModelInputForGPUBuilder):
|
||||
"""Build ModelInputForGPU from SequenceGroupMetadata."""
|
||||
|
||||
# Note: ideally we would be using a dataclass(kw_only=True)
|
||||
# here, so that this can be subclassed easily,
|
||||
# but kw_only is not supported in python<3.10.
|
||||
def build(self) -> ModelInputForGPU:
|
||||
"""Finalize the builder intermediate data and
|
||||
create on-device tensors.
|
||||
"""
|
||||
# Combine and flatten intermediate data.
|
||||
input_tokens = [
|
||||
flatten_2d_lists(inter_data.input_tokens)
|
||||
for inter_data in self.inter_data_list
|
||||
]
|
||||
if not input_tokens:
|
||||
# This may happen when all prefill requests hit
|
||||
# prefix caching and there is no decode request.
|
||||
return self.model_input_cls()
|
||||
|
||||
mrope_input_positions: Optional[List[List[int]]] = None
|
||||
if any(inter_data.mrope_input_positions is not None
|
||||
for inter_data in self.inter_data_list):
|
||||
mrope_input_positions = [[] for _ in range(3)]
|
||||
# calculate max position length for padding
|
||||
input_position_lens = [
|
||||
len(inter_data.input_positions[0])
|
||||
for inter_data in self.inter_data_list
|
||||
]
|
||||
max_pos_len = max(input_position_lens)
|
||||
|
||||
for idx in range(3):
|
||||
for inter_data in self.inter_data_list:
|
||||
msections = inter_data.mrope_input_positions
|
||||
if msections is None:
|
||||
for _seq_input_positions in inter_data.input_positions:
|
||||
# zero pad
|
||||
_seq_input_positions.extend(
|
||||
[0] *
|
||||
(max_pos_len - len(_seq_input_positions)))
|
||||
mrope_input_positions[idx].extend(
|
||||
_seq_input_positions)
|
||||
else:
|
||||
for _seq_mrope_input_positions in msections:
|
||||
# zero pad
|
||||
_seq_mrope_input_positions[idx].extend(
|
||||
[0] * (max_pos_len -
|
||||
len(_seq_mrope_input_positions[idx])))
|
||||
mrope_input_positions[idx].extend(
|
||||
_seq_mrope_input_positions[idx])
|
||||
input_positions = None
|
||||
else:
|
||||
input_positions = [
|
||||
flatten_2d_lists(inter_data.input_positions)
|
||||
for inter_data in self.inter_data_list
|
||||
]
|
||||
|
||||
seq_lens = []
|
||||
max_decode_seq_len = 0
|
||||
for inter_data in self.inter_data_list:
|
||||
seq_lens.extend(inter_data.seq_lens)
|
||||
if not inter_data.is_prompt:
|
||||
max_decode_seq_len = max(max_decode_seq_len,
|
||||
max(inter_data.seq_lens))
|
||||
query_lens = flatten_2d_lists(
|
||||
[inter_data.query_lens for inter_data in self.inter_data_list])
|
||||
# Mapping from request IDs to sequence IDs. Used for Jamba models
|
||||
# that manages the cache by itself.
|
||||
request_ids_to_seq_ids = {
|
||||
data.request_id: data.seq_ids
|
||||
for data in self.inter_data_list
|
||||
}
|
||||
|
||||
batch_size = len(input_tokens)
|
||||
|
||||
# If cuda graph can be used, pad tensors accordingly.
|
||||
# See `capture_model` API for more details.
|
||||
# vLLM uses cuda graph only for decoding requests.
|
||||
cuda_graph_pad_size = -1
|
||||
|
||||
if self.inter_data_list[0].is_prompt:
|
||||
input_tokens_tensor = make_tensor_with_pad(
|
||||
input_tokens, 0, dtype=torch.int, device=self.runner.device)
|
||||
input_tokens_tensor = torch.flatten(input_tokens_tensor)
|
||||
if mrope_input_positions is not None:
|
||||
mrope_input_positions_tensor = make_tensor_with_pad(
|
||||
mrope_input_positions,
|
||||
0,
|
||||
dtype=torch.int,
|
||||
device=self.runner.device)
|
||||
input_positions_tensor = torch.tensor(
|
||||
mrope_input_positions_tensor,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
else:
|
||||
input_positions_tensor = make_tensor_with_pad(
|
||||
input_positions,
|
||||
0,
|
||||
dtype=torch.int,
|
||||
device=self.runner.device)
|
||||
input_positions_tensor = torch.flatten(input_positions_tensor)
|
||||
|
||||
max_seq_len = max(seq_lens)
|
||||
seq_lens = len(seq_lens) * [max_seq_len]
|
||||
else:
|
||||
input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens),
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
if mrope_input_positions is not None:
|
||||
input_positions_tensor = torch.tensor(
|
||||
mrope_input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
else:
|
||||
input_positions_tensor = torch.tensor(
|
||||
flatten_2d_lists(input_positions),
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
|
||||
# Sequence and query lengths.
|
||||
seq_lens.extend([1] * cuda_graph_pad_size)
|
||||
|
||||
# Attention metadata.
|
||||
attn_metadata = self.attn_metadata_builder.build(
|
||||
seq_lens, query_lens, cuda_graph_pad_size, batch_size)
|
||||
|
||||
# LoRA data.
|
||||
lora_requests = set()
|
||||
lora_mapping = None
|
||||
if self.enable_lora:
|
||||
lora_requests = set(r for data in self.inter_data_list
|
||||
for r in data.lora_requests)
|
||||
lora_index_mapping = flatten_2d_lists([
|
||||
flatten_2d_lists(inter_data.lora_index_mapping)
|
||||
for inter_data in self.inter_data_list
|
||||
])
|
||||
lora_index_mapping.extend([0] * cuda_graph_pad_size)
|
||||
lora_prompt_mapping = flatten_2d_lists([
|
||||
flatten_2d_lists(inter_data.lora_prompt_mapping)
|
||||
for inter_data in self.inter_data_list
|
||||
])
|
||||
lora_mapping = LoRAMapping(
|
||||
**dict(index_mapping=lora_index_mapping,
|
||||
prompt_mapping=lora_prompt_mapping,
|
||||
is_prefill=not self.decode_only))
|
||||
|
||||
# Prompt adapter data.
|
||||
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
|
||||
prompt_adapter_mapping = None
|
||||
if self.enable_prompt_adapter:
|
||||
prompt_adapter_requests = set(
|
||||
data.prompt_adapter_request for data in self.inter_data_list
|
||||
if data.prompt_adapter_request is not None)
|
||||
prompt_adapter_index_mapping = flatten_2d_lists([
|
||||
inter_data.prompt_adapter_index_mapping
|
||||
for inter_data in self.inter_data_list
|
||||
])
|
||||
prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
|
||||
prompt_adapter_prompt_mapping = flatten_2d_lists([
|
||||
inter_data.prompt_adapter_prompt_mapping
|
||||
for inter_data in self.inter_data_list
|
||||
])
|
||||
prompt_adapter_mapping = PromptAdapterMapping(
|
||||
prompt_adapter_index_mapping,
|
||||
prompt_adapter_prompt_mapping,
|
||||
)
|
||||
|
||||
# Multi-modal data.
|
||||
multi_modal_kwargs_list = [
|
||||
data.multi_modal_kwargs for data in self.inter_data_list
|
||||
if data.multi_modal_kwargs is not None
|
||||
]
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens_tensor,
|
||||
input_positions=input_positions_tensor,
|
||||
attn_metadata=attn_metadata,
|
||||
seq_lens=seq_lens,
|
||||
query_lens=query_lens,
|
||||
lora_mapping=lora_mapping,
|
||||
lora_requests=lora_requests,
|
||||
multi_modal_kwargs=multi_modal_kwargs,
|
||||
request_ids_to_seq_ids=request_ids_to_seq_ids,
|
||||
finished_requests_ids=self.finished_requests_ids,
|
||||
prompt_adapter_mapping=prompt_adapter_mapping,
|
||||
prompt_adapter_requests=prompt_adapter_requests)
|
||||
|
||||
class InterDataForSeqGroup:
|
||||
"""Intermediate data for the current sequence group."""
|
||||
|
||||
def simple_reinit(self):
|
||||
self.input_tokens[0].clear() # type: ignore
|
||||
self.input_positions[0].clear() # type: ignore
|
||||
self.token_types[0].clear() # type: ignore
|
||||
self.mrope_input_positions = None # type: ignore
|
||||
self.seq_lens[0] = 0 # type: ignore
|
||||
self.orig_seq_lens[0] = 0 # type: ignore
|
||||
self.query_lens[0] = 0 # type: ignore
|
||||
self.context_lens[0] = 0 # type: ignore
|
||||
self.curr_sliding_window_blocks[0] = 0 # type: ignore
|
||||
self.lora_index_mapping.clear() # type: ignore
|
||||
self.lora_prompt_mapping.clear() # type: ignore
|
||||
self.lora_requests.clear() # type: ignore
|
||||
self.prompt_adapter_index_mapping.clear() # type: ignore
|
||||
self.prompt_adapter_prompt_mapping.clear() # type: ignore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
# From sequence group metadata.
|
||||
request_id: str,
|
||||
seq_ids: List[int],
|
||||
is_prompt: bool,
|
||||
block_tables: Optional[Dict[int, List[int]]],
|
||||
computed_block_nums: List[int],
|
||||
n_seqs: int = 0,
|
||||
|
||||
# Input tokens and positions.
|
||||
input_tokens: Optional[List[List[int]]] = None,
|
||||
input_positions: Optional[List[List[int]]] = None,
|
||||
token_types: Optional[List[List[int]]] = None,
|
||||
mrope_input_positions: Optional[List[List[List[int]]]] = None,
|
||||
|
||||
# The sequence length (may be capped to the sliding window).
|
||||
seq_lens: Optional[List[int]] = None,
|
||||
# The original sequence length (before applying sliding window).
|
||||
# This is used to compute slot mapping.
|
||||
orig_seq_lens: Optional[List[int]] = None,
|
||||
# The query length.
|
||||
query_lens: Optional[List[int]] = None,
|
||||
# The number of tokens that are already computed.
|
||||
context_lens: Optional[List[int]] = None,
|
||||
# The current sliding window block.
|
||||
curr_sliding_window_blocks: Optional[List[int]] = None,
|
||||
|
||||
# LoRA inputs.
|
||||
lora_index_mapping: Optional[List[List[int]]] = None,
|
||||
lora_prompt_mapping: Optional[List[List[int]]] = None,
|
||||
lora_requests: Optional[Set[LoRARequest]] = None,
|
||||
|
||||
# Prompt adapter inputs.
|
||||
prompt_adapter_index_mapping: Optional[List[int]] = None,
|
||||
prompt_adapter_prompt_mapping: Optional[List[int]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
|
||||
# Multi-modal inputs.
|
||||
multi_modal_kwargs: Optional[MultiModalKwargs] = None,
|
||||
multi_modal_placeholder_maps: Optional[Dict[
|
||||
str, MultiModalPlaceholderMap]] = None,
|
||||
|
||||
# Whether the prefix cache is hit (prefill only).
|
||||
prefix_cache_hit: bool = False,
|
||||
reinit: bool = False,
|
||||
reinit_use_defaults: bool = False,
|
||||
encoder_seq_len: int = 0,
|
||||
):
|
||||
if reinit:
|
||||
assert len(self.seq_ids) == len(seq_ids) # type: ignore
|
||||
for i, seq_id in enumerate(seq_ids):
|
||||
self.seq_ids[i] = seq_id # type: ignore
|
||||
else:
|
||||
self.seq_ids = seq_ids
|
||||
|
||||
self.request_id = request_id
|
||||
self.is_prompt = is_prompt
|
||||
self.block_tables = block_tables
|
||||
self.computed_block_nums = computed_block_nums
|
||||
self.n_seqs = n_seqs
|
||||
self.encoder_seq_len = encoder_seq_len
|
||||
|
||||
if reinit:
|
||||
if len(self.seq_ids) == 1 and reinit_use_defaults:
|
||||
self.simple_reinit()
|
||||
else:
|
||||
if input_tokens:
|
||||
self.input_tokens = input_tokens
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.input_tokens[seq_id].clear()
|
||||
|
||||
if input_positions:
|
||||
self.input_positions = input_positions
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.input_positions[seq_id].clear()
|
||||
|
||||
if token_types:
|
||||
self.token_types = token_types
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.token_types[seq_id].clear()
|
||||
|
||||
self.mrope_input_positions = None
|
||||
|
||||
if seq_lens:
|
||||
self.seq_lens = seq_lens
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.seq_lens[seq_id] = 0
|
||||
|
||||
if orig_seq_lens:
|
||||
self.orig_seq_lens = orig_seq_lens
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.orig_seq_lens[seq_id] = 0
|
||||
|
||||
if query_lens:
|
||||
self.query_lens = query_lens
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.query_lens[seq_id] = 0
|
||||
|
||||
if context_lens:
|
||||
self.context_lens = context_lens
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.context_lens[seq_id] = 0
|
||||
|
||||
if curr_sliding_window_blocks:
|
||||
self.curr_sliding_window_blocks = \
|
||||
curr_sliding_window_blocks
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.curr_sliding_window_blocks[seq_id] = 0
|
||||
|
||||
if lora_index_mapping:
|
||||
self.lora_index_mapping = lora_index_mapping
|
||||
else:
|
||||
self.lora_index_mapping.clear()
|
||||
|
||||
if lora_prompt_mapping:
|
||||
self.lora_prompt_mapping = lora_prompt_mapping
|
||||
else:
|
||||
self.lora_prompt_mapping.clear()
|
||||
|
||||
if lora_requests:
|
||||
self.lora_requests = lora_requests
|
||||
else:
|
||||
self.lora_requests.clear()
|
||||
|
||||
if prompt_adapter_index_mapping:
|
||||
self.prompt_adapter_index_mapping = \
|
||||
prompt_adapter_index_mapping
|
||||
else:
|
||||
self.prompt_adapter_index_mapping.clear()
|
||||
|
||||
if prompt_adapter_prompt_mapping:
|
||||
self.prompt_adapter_prompt_mapping = \
|
||||
prompt_adapter_prompt_mapping
|
||||
else:
|
||||
self.prompt_adapter_prompt_mapping.clear()
|
||||
|
||||
else:
|
||||
self.input_tokens = input_tokens or []
|
||||
self.input_positions = input_positions or []
|
||||
self.token_types = token_types or []
|
||||
self.mrope_input_positions = mrope_input_positions or None
|
||||
self.seq_lens = seq_lens or []
|
||||
self.orig_seq_lens = orig_seq_lens or []
|
||||
self.query_lens = query_lens or []
|
||||
self.context_lens = context_lens or []
|
||||
self.curr_sliding_window_blocks = \
|
||||
curr_sliding_window_blocks or []
|
||||
|
||||
self.lora_index_mapping = lora_index_mapping or []
|
||||
self.lora_prompt_mapping = lora_prompt_mapping or []
|
||||
self.lora_requests = lora_requests or set()
|
||||
|
||||
self.prompt_adapter_index_mapping = (
|
||||
prompt_adapter_index_mapping or [])
|
||||
self.prompt_adapter_prompt_mapping = (
|
||||
prompt_adapter_prompt_mapping or [])
|
||||
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.multi_modal_kwargs = multi_modal_kwargs
|
||||
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
|
||||
self.prefix_cache_hit = prefix_cache_hit
|
||||
|
||||
self.n_seqs = len(self.seq_ids)
|
||||
|
||||
if not reinit:
|
||||
self.__post_init__()
|
||||
|
||||
def __post_init__(self):
|
||||
self.n_seqs = len(self.seq_ids)
|
||||
|
||||
self.input_tokens = [[] for _ in range(self.n_seqs)]
|
||||
self.input_positions = [[] for _ in range(self.n_seqs)]
|
||||
self.token_types = [[] for _ in range(self.n_seqs)]
|
||||
self.mrope_input_positions = None
|
||||
self.seq_lens = [0] * self.n_seqs
|
||||
self.orig_seq_lens = [0] * self.n_seqs
|
||||
self.query_lens = [0] * self.n_seqs
|
||||
self.context_lens = [0] * self.n_seqs
|
||||
self.curr_sliding_window_blocks = [0] * self.n_seqs
|
||||
|
||||
self.lora_index_mapping = []
|
||||
self.lora_prompt_mapping = []
|
||||
|
||||
|
||||
class NPUModelRunner(ModelRunner):
|
||||
"""
|
||||
NPU model runner with sampling step.
|
||||
"""
|
||||
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
_builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self,
|
||||
tensor_dict: Dict[str, Any],
|
||||
) -> ModelInputForGPUWithSamplingMetadata:
|
||||
model_input = \
|
||||
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
)
|
||||
return model_input
|
||||
|
||||
@current_platform.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
# Enable top-k sampling to reflect the accurate memory usage.
|
||||
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
|
||||
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
max_num_seqs = self.scheduler_config.max_num_seqs
|
||||
# This represents the maximum number of different requests
|
||||
# that will have unique loras, an therefore the max amount of memory
|
||||
# consumption create dummy lora request copies from the lora request
|
||||
# passed in, which contains a lora from the lora warmup path.
|
||||
dummy_lora_requests: List[LoRARequest] = []
|
||||
dummy_lora_requests_per_seq: List[LoRARequest] = []
|
||||
if self.lora_config:
|
||||
assert self.lora_manager is not None
|
||||
with self.lora_manager.dummy_lora_cache():
|
||||
for idx in range(self.lora_config.max_loras):
|
||||
lora_id = idx + 1
|
||||
dummy_lora_request = LoRARequest(
|
||||
lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path",
|
||||
)
|
||||
self.lora_manager.add_dummy_lora(dummy_lora_request,
|
||||
rank=LORA_WARMUP_RANK)
|
||||
dummy_lora_requests.append(dummy_lora_request)
|
||||
dummy_lora_requests_per_seq = [
|
||||
dummy_lora_requests[idx % len(dummy_lora_requests)]
|
||||
for idx in range(max_num_seqs)
|
||||
]
|
||||
|
||||
# Profile memory usage with max_num_sequences sequences and the total
|
||||
# number of tokens equal to max_num_batched_tokens.
|
||||
seqs: List[SequenceGroupMetadata] = []
|
||||
# Additional GPU memory may be needed for multi-modal encoding, which
|
||||
# needs to be accounted for when calculating the GPU blocks for
|
||||
# vLLM blocker manager.
|
||||
# To exercise the worst scenario for GPU memory consumption,
|
||||
# the number of seqs (batch_size) is chosen to maximize the number
|
||||
# of images processed.
|
||||
|
||||
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
|
||||
self.model_config)
|
||||
if max_mm_tokens > 0:
|
||||
max_num_seqs_orig = max_num_seqs
|
||||
max_num_seqs = min(max_num_seqs,
|
||||
max_num_batched_tokens // max_mm_tokens)
|
||||
if max_num_seqs < 1:
|
||||
expr = (f"min({max_num_seqs_orig}, "
|
||||
f"{max_num_batched_tokens} // {max_mm_tokens})")
|
||||
logger.warning(
|
||||
"Computed max_num_seqs (%s) to be less than 1. "
|
||||
"Setting it to the minimum value of 1.", expr)
|
||||
max_num_seqs = 1
|
||||
|
||||
batch_size = 0
|
||||
for group_id in range(max_num_seqs):
|
||||
seq_len = (max_num_batched_tokens // max_num_seqs +
|
||||
(group_id < max_num_batched_tokens % max_num_seqs))
|
||||
batch_size += seq_len
|
||||
|
||||
dummy_data = self.input_registry \
|
||||
.dummy_data_for_profiling(self.model_config,
|
||||
seq_len,
|
||||
self.mm_registry)
|
||||
|
||||
seq = SequenceGroupMetadata(
|
||||
request_id=str(group_id),
|
||||
is_prompt=True,
|
||||
seq_data={group_id: dummy_data.seq_data},
|
||||
sampling_params=sampling_params,
|
||||
block_tables=None,
|
||||
lora_request=dummy_lora_requests_per_seq[group_id]
|
||||
if dummy_lora_requests_per_seq else None,
|
||||
multi_modal_data=dummy_data.multi_modal_data,
|
||||
multi_modal_placeholders=dummy_data.multi_modal_placeholders,
|
||||
)
|
||||
seqs.append(seq)
|
||||
|
||||
# Run the model with the dummy inputs.
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
# it is important to create tensors inside the loop, rather than
|
||||
# multiplying the list, to avoid Dynamo from treating them as
|
||||
# tensor aliasing.
|
||||
kv_caches = [
|
||||
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
finished_requests_ids = [seq.request_id for seq in seqs]
|
||||
model_input = self.prepare_model_input(
|
||||
seqs, finished_requests_ids=finished_requests_ids)
|
||||
intermediate_tensors = None
|
||||
if not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = self.model.make_empty_intermediate_tensors(
|
||||
batch_size=batch_size,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device)
|
||||
self.execute_model(model_input, kv_caches, intermediate_tensors)
|
||||
current_platform.synchronize()
|
||||
return
|
||||
|
||||
@current_platform.inference_mode()
|
||||
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
|
||||
"""NPU graph capture a model.
|
||||
TODO: not support now
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None,
|
||||
) -> ModelInputForGPUWithSamplingMetadata:
|
||||
"""Prepare the model input based on a given sequence group, including
|
||||
metadata for the sampling step.
|
||||
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
|
||||
The result tensors and data structure also batches input in prefill
|
||||
-> decode order. For example,
|
||||
- input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
- input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
If cuda graph is required, this API automatically pads inputs.
|
||||
"""
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
seq_group_metadata_list, finished_requests_ids)
|
||||
if get_pp_group().is_last_rank:
|
||||
# Sampling metadata is only required for the final pp group
|
||||
generators = self.get_generators(finished_requests_ids)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
model_input.seq_lens,
|
||||
model_input.query_lens,
|
||||
self.device,
|
||||
self.pin_memory,
|
||||
generators,
|
||||
self.sampling_metadata_cache,
|
||||
# TODO (cmq): enable this after supported in vllm
|
||||
# pad_for_invariant_seq_len=True,
|
||||
)
|
||||
else:
|
||||
sampling_metadata = None
|
||||
is_prompt = (seq_group_metadata_list[0].is_prompt
|
||||
if seq_group_metadata_list else None)
|
||||
return dataclasses.replace(model_input,
|
||||
sampling_metadata=sampling_metadata,
|
||||
is_prompt=is_prompt,
|
||||
virtual_engine=virtual_engine)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
18
vllm_ascend/ops/__init__.py
Normal file
18
vllm_ascend/ops/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import vllm_ascend.ops.layernorm # noqa
|
||||
40
vllm_ascend/ops/layernorm.py
Normal file
40
vllm_ascend/ops/layernorm.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
if residual is not None:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight,
|
||||
self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
||||
return x
|
||||
|
||||
|
||||
RMSNorm.forward_oot = forward_oot
|
||||
115
vllm_ascend/platform.py
Normal file
115
vllm_ascend/platform.py
Normal file
@@ -0,0 +1,115 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
except ImportError:
|
||||
print("Failed to import torch_npu.")
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.platforms import Platform, PlatformEnum
|
||||
|
||||
os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
|
||||
|
||||
|
||||
def _device_id_to_physical_device_id(device_id: int) -> int:
|
||||
if "ASCEND_RT_VISIBLE_DEVICES" in os.environ:
|
||||
device_ids = os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")
|
||||
if device_ids == [""]:
|
||||
raise RuntimeError("ASCEND_RT_VISIBLE_DEVICES is set to empty"
|
||||
"string, which means Ascend NPU support is"
|
||||
"disabled.")
|
||||
physical_device_id = device_ids[device_id]
|
||||
return int(physical_device_id)
|
||||
else:
|
||||
return device_id
|
||||
|
||||
|
||||
class NPUPlatform(Platform):
|
||||
|
||||
_enum = PlatformEnum.OOT
|
||||
device_name: str = "npu"
|
||||
device_type: str = "npu"
|
||||
simple_compile_backend: str = "npu"
|
||||
ray_device_key: str = "NPU"
|
||||
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
|
||||
|
||||
@classmethod
|
||||
def get_device_capability(cls, device_id: int = 0):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
physical_device_id = _device_id_to_physical_device_id(device_id)
|
||||
return torch.npu.get_device_name(physical_device_id)
|
||||
|
||||
@classmethod
|
||||
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
return torch.inference_mode()
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device):
|
||||
torch.npu.set_device(device)
|
||||
|
||||
@classmethod
|
||||
def empty_cache(cls):
|
||||
torch.npu.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def synchronize(cls):
|
||||
torch.npu.synchronize()
|
||||
|
||||
@classmethod
|
||||
def mem_get_info(cls) -> Tuple[int, int]:
|
||||
return torch.npu.mem_get_info()
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
# Register ops when setup.
|
||||
from vllm_ascend import ops # noqa: F401
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm_ascend.worker.NPUWorker"
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 16
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla):
|
||||
return "vllm_ascend.attention.AscendAttentionBackend"
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(cls,
|
||||
device: Optional[torch.types.Device] = None
|
||||
) -> float:
|
||||
torch.npu.reset_peak_memory_stats(device)
|
||||
return torch.npu.max_memory_allocated(device)
|
||||
|
||||
@classmethod
|
||||
def get_device_communicator_cls(cls) -> str:
|
||||
return "vllm_ascend.communicator.NPUCommunicator"
|
||||
481
vllm_ascend/worker.py
Normal file
481
vllm_ascend/worker.py
Normal file
@@ -0,0 +1,481 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import gc
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm import envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
||||
from vllm.utils import bind_kv_cache
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
from vllm.worker.model_runner_base import ModelRunnerBase
|
||||
from vllm.worker.pooling_model_runner import PoolingModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
from vllm_ascend.model_runner import NPUModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
"""A worker class that executes (a partition of) the model on a NPU.
|
||||
Each worker is associated with a single NPU. The worker is responsible for
|
||||
maintaining the KV cache and executing the model on the NPU. In case of
|
||||
distributed inference, each worker is assigned a partition of the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
|
||||
) -> None:
|
||||
|
||||
WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
# distribute related config
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
if is_driver_worker:
|
||||
assert rank % self.parallel_config.tensor_parallel_size == 0, \
|
||||
"Driver worker should be rank 0 of tensor parallel group."
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Return hidden states from target model if the draft model is an
|
||||
# mlp_speculator
|
||||
speculative_config = self.speculative_config
|
||||
model_config = self.model_config
|
||||
speculative_args = {} if speculative_config is None \
|
||||
or (speculative_config.draft_model_config.model ==
|
||||
model_config.model) \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type
|
||||
not in ["medusa", "mlp_speculator", "eagle"]) \
|
||||
else {"return_hidden_states": True}
|
||||
|
||||
ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
|
||||
if model_config.runner_type == "pooling":
|
||||
ModelRunnerClass = PoolingModelRunner
|
||||
elif self.model_config.is_encoder_decoder:
|
||||
ModelRunnerClass = EncoderDecoderModelRunner
|
||||
self.model_runner: ModelRunnerBase = ModelRunnerClass(
|
||||
vllm_config=self.vllm_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
**speculative_args,
|
||||
)
|
||||
if model_runner_cls is not None:
|
||||
self.model_runner = model_runner_cls(self.model_runner)
|
||||
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
self.cache_engine: List[CacheEngine]
|
||||
# Initialize gpu_cache as embedding models don't initialize kv_caches
|
||||
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
|
||||
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir)
|
||||
|
||||
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
||||
export_type=torch_npu.profiler.ExportType.Text,
|
||||
profiler_level=torch_npu.profiler.ProfilerLevel.Level0,
|
||||
msprof_tx=False,
|
||||
aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
|
||||
l2_cache=False,
|
||||
op_attr=False,
|
||||
data_simplification=False,
|
||||
record_op_args=False,
|
||||
gc_detect_threshold=None,
|
||||
)
|
||||
|
||||
self.profiler = torch_npu.profiler.profile(
|
||||
activities=[
|
||||
torch_npu.profiler.ProfilerActivity.CPU,
|
||||
torch_npu.profiler.ProfilerActivity.NPU,
|
||||
],
|
||||
with_stack=True,
|
||||
profile_memory=True,
|
||||
with_modules=True,
|
||||
experimental_config=experimental_config,
|
||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir))
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def start_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.start()
|
||||
|
||||
def stop_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.stop()
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "npu":
|
||||
# # This env var set by Ray causes exceptions with graph building.
|
||||
# os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
||||
self.device = torch.device(f"npu:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
|
||||
current_platform.empty_cache()
|
||||
self.init_npu_memory = current_platform.mem_get_info()[0]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
init_worker_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
self.model_runner.save_sharded_state(
|
||||
path,
|
||||
pattern=pattern,
|
||||
max_size=max_size,
|
||||
)
|
||||
|
||||
def save_tensorized_model(
|
||||
self,
|
||||
tensorizer_config: TensorizerConfig,
|
||||
) -> None:
|
||||
self.model_runner.save_tensorized_model(
|
||||
tensorizer_config=tensorizer_config, )
|
||||
|
||||
@current_platform.inference_mode()
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
KV blocks may be allocated without OOMs.
|
||||
The engine will first conduct a profiling of the existing memory usage.
|
||||
Then, it calculate the maximum possible number of NPU and CPU blocks
|
||||
that can be allocated with the remaining free memory.
|
||||
.. tip::
|
||||
You may limit the usage of NPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
# Profile the memory usage of the model and get the maximum number of
|
||||
# cache blocks that can be allocated with the remaining free memory.
|
||||
current_platform.empty_cache()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
self.model_runner.profile_run()
|
||||
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
free_npu_memory, total_npu_memory = current_platform.mem_get_info()
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
peak_memory = self.init_npu_memory - free_npu_memory
|
||||
assert peak_memory > 0, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial free memory {self.init_npu_memory}, current free memory"
|
||||
f" {free_npu_memory}. This happens when the NPU memory was "
|
||||
"not properly cleaned up before initializing the vLLM instance.")
|
||||
|
||||
cache_block_size = self.get_cache_block_size_bytes()
|
||||
num_npu_blocks = int(
|
||||
(total_npu_memory * self.cache_config.gpu_memory_utilization -
|
||||
peak_memory) // cache_block_size)
|
||||
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||
cache_block_size)
|
||||
num_npu_blocks = max(num_npu_blocks, 0)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
if self.model_runner.lora_manager:
|
||||
self.model_runner.remove_all_loras()
|
||||
gc.collect()
|
||||
# TODO: don`t need impl this func after empty_cache in
|
||||
# Worker.determine_num_available_blocks() unified`
|
||||
current_platform.empty_cache()
|
||||
return num_npu_blocks, num_cpu_blocks
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Allocate NPU and CPU KV cache with the specified number of blocks.
|
||||
"""
|
||||
raise_if_cache_size_invalid(num_gpu_blocks,
|
||||
self.cache_config.block_size,
|
||||
self.cache_config.is_attention_free,
|
||||
self.model_config.max_model_len)
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
self._init_cache_engine()
|
||||
self._warm_up_model()
|
||||
|
||||
def _init_cache_engine(self):
|
||||
assert self.cache_config.num_gpu_blocks is not None
|
||||
self.cache_engine = [
|
||||
CacheEngine(self.cache_config, self.model_config,
|
||||
self.parallel_config, self.device_config)
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
self.gpu_cache = [
|
||||
self.cache_engine[ve].gpu_cache
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
bind_kv_cache(self.compilation_config.static_forward_context,
|
||||
self.gpu_cache)
|
||||
|
||||
def _warm_up_model(self) -> None:
|
||||
# model capture is not supported, thus we just set seed here.
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
@property
|
||||
def do_metadata_broadcast(self) -> bool:
|
||||
return self.parallel_config.tensor_parallel_size > 1
|
||||
|
||||
@property
|
||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||
return self.gpu_cache
|
||||
|
||||
@torch.inference_mode()
|
||||
def prepare_worker_input(
|
||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
num_steps = execute_model_req.num_steps
|
||||
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
|
||||
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
|
||||
# they contain parameters to launch cudamemcpyasync.
|
||||
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
|
||||
device="cpu",
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
|
||||
device="cpu",
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
# `blocks_to_copy` is a gpu tensor. The src and tgt of
|
||||
# blocks to copy are in the same device, and `blocks_to_copy`
|
||||
# can be used directly within cuda kernels.
|
||||
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
|
||||
device=self.device,
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
|
||||
return WorkerInput(
|
||||
num_seq_groups=num_seq_groups,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
num_steps=num_steps,
|
||||
)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model_runner.get_model()
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_worker(self, worker_input: WorkerInput) -> None:
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
# Issue cache operations.
|
||||
if (worker_input.blocks_to_swap_in is not None
|
||||
and worker_input.blocks_to_swap_in.numel() > 0):
|
||||
self.cache_engine[virtual_engine].swap_in(
|
||||
worker_input.blocks_to_swap_in)
|
||||
if (worker_input.blocks_to_swap_out is not None
|
||||
and worker_input.blocks_to_swap_out.numel() > 0):
|
||||
self.cache_engine[virtual_engine].swap_out(
|
||||
worker_input.blocks_to_swap_out)
|
||||
if (worker_input.blocks_to_copy is not None
|
||||
and worker_input.blocks_to_copy.numel() > 0):
|
||||
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
|
||||
|
||||
def _get_cached_seq_group_metadata(
|
||||
self,
|
||||
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
|
||||
SequenceGroupMetadataDelta]],
|
||||
finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
|
||||
"""Return a list of cached Sequence Group Metadata after updating its
|
||||
state.
|
||||
|
||||
It is used because scheduler only sends delta to workers to reduce
|
||||
the data payload size. The function also cleans up cache based on
|
||||
a given `finished_request_ids`.
|
||||
"""
|
||||
new_seq_group_metadata_list = []
|
||||
for metadata_or_delta in seq_group_metadata_list:
|
||||
request_id = metadata_or_delta.request_id
|
||||
if request_id not in self._seq_group_metadata_cache:
|
||||
# The first prefill.
|
||||
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
|
||||
self._seq_group_metadata_cache[request_id] = metadata_or_delta
|
||||
else:
|
||||
# The first prefill is already cached.
|
||||
if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
|
||||
self._seq_group_metadata_cache[request_id].apply_delta(
|
||||
metadata_or_delta)
|
||||
else:
|
||||
# If metadata snapshot is sent again, it is
|
||||
# preempted. Reset the cache because we need to start
|
||||
# from scratch.
|
||||
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
|
||||
self._seq_group_metadata_cache[
|
||||
request_id] = metadata_or_delta
|
||||
|
||||
new_seq_group_metadata_list.append(
|
||||
self._seq_group_metadata_cache[request_id])
|
||||
|
||||
# Clean up finished ids
|
||||
for finished_id in finished_request_ids:
|
||||
del self._seq_group_metadata_cache[finished_id]
|
||||
|
||||
return new_seq_group_metadata_list
|
||||
|
||||
def _execute_model_spmd(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
if execute_model_req is not None:
|
||||
new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.finished_requests_ids)
|
||||
|
||||
execute_model_req.seq_group_metadata_list = (
|
||||
new_seq_group_metadata_list)
|
||||
output = super()._execute_model_spmd(execute_model_req,
|
||||
intermediate_tensors)
|
||||
return output
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
raise NotImplementedError(
|
||||
"LoRA is not implemented for NPU backend currently.")
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"LoRA is not implemented for NPU backend currently.")
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"LoRA is not implemented for NPU backend currently.")
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise NotImplementedError(
|
||||
"LoRA is not implemented for NPU backend currently.")
|
||||
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Prompt Adapter is not implemented for NPU backend currently.")
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Prompt Adapter is not implemented for NPU backend currently.")
|
||||
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Prompt Adapter is not implemented for NPU backend currently.")
|
||||
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
raise NotImplementedError(
|
||||
"Prompt Adapter is not implemented for NPU backend currently.")
|
||||
|
||||
@property
|
||||
def max_model_len(self) -> int:
|
||||
return self.model_config.max_model_len
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.model_runner.vocab_size
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Get the size of the KV cache block size in bytes.
|
||||
"""
|
||||
return CacheEngine.get_cache_block_size(self.cache_config,
|
||||
self.model_config,
|
||||
self.parallel_config)
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
backend: str = "hccl") -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank, backend)
|
||||
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
|
||||
max_model_len) -> None:
|
||||
if is_attention_free and num_gpu_blocks != 0:
|
||||
raise ValueError("No memory should be allocated for the cache blocks "
|
||||
f"for an attention-free model, but {num_gpu_blocks}"
|
||||
"blocks are allocated.")
|
||||
if not is_attention_free and num_gpu_blocks <= 0:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `gpu_memory_utilization` when "
|
||||
"initializing the engine.")
|
||||
max_seq_len = block_size * num_gpu_blocks
|
||||
if not is_attention_free and max_model_len > max_seq_len:
|
||||
raise ValueError(
|
||||
f"The model's max seq len ({max_model_len}) "
|
||||
"is larger than the maximum number of tokens that can be "
|
||||
f"stored in KV cache ({max_seq_len}). Try increasing "
|
||||
"`gpu_memory_utilization` or decreasing `max_model_len` when "
|
||||
"initializing the engine.")
|
||||
Reference in New Issue
Block a user