### What this PR does / why we need it?
Add support for V1 Engine.
Please note that this is only the initial version; some parts may still
need to be fixed or optimized in the future. Feel free to leave comments
for us.
### Does this PR introduce _any_ user-facing change?
To use the V1 Engine on an NPU device, you need to set the environment
variables shown below:
```bash
export VLLM_USE_V1=1
export VLLM_WORKER_MULTIPROC_METHOD=spawn
```
If you are using vLLM for offline inference, you must add a `__main__`
guard like:
```python
if __name__ == '__main__':
    llm = vllm.LLM(...)
```
Find more details in the
[vLLM troubleshooting guide](https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing).
### How was this patch tested?
I have tested the online serving with `Qwen2.5-7B-Instruct` using this
command:
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --max_model_len 26240
```
Query the model with input prompts:
```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2.5-7B-Instruct",
"prompt": "The future of AI is",
"max_tokens": 7,
"temperature": 0
}'
```
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: didongli182 <didongli@huawei.com>
141 lines
4.8 KiB
Python
141 lines
4.8 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import os
|
|
from typing import TYPE_CHECKING, Optional, Tuple
|
|
|
|
import torch
|
|
import torch_npu # noqa: F401
|
|
import vllm.envs as envs
|
|
from vllm.config import CompilationLevel, VllmConfig
|
|
from vllm.logger import init_logger
|
|
from vllm.platforms import Platform, PlatformEnum
|
|
|
|
if not TYPE_CHECKING:
    # At runtime only a placeholder is bound, so annotations that mention
    # FlexibleArgumentParser can still be evaluated without importing
    # vllm.utils from this module.
    FlexibleArgumentParser = None
else:
    # Static type checkers see the real class.
    from vllm.utils import FlexibleArgumentParser

# Set as an import-time side effect of this module.
# NOTE(review): presumably asks Ray not to overwrite
# ASCEND_RT_VISIBLE_DEVICES in its worker processes - confirm against the
# Ray accelerator documentation.
os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"

# Module-level logger, named after this module.
logger = init_logger(__name__)
|
|
|
|
|
|
class NPUPlatform(Platform):
    """Platform implementation that routes vLLM device operations to NPU.

    The class declares itself with ``PlatformEnum.OOT`` and forwards device
    management calls (set device, cache, synchronisation, memory queries) to
    the ``torch.npu`` namespace. Worker, attention-backend and communicator
    classes are returned as dotted-path strings inside ``vllm_ascend``.
    """

    _enum = PlatformEnum.OOT
    device_name: str = "npu"
    device_type: str = "npu"
    simple_compile_backend: str = "npu"
    ray_device_key: str = "NPU"
    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
    dispatch_key: str = "PrivateUse1"

    supported_quantization: list[str] = ["ascend"]

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """Import the Ascend quantization config module.

        Args:
            parser: Optional argument parser; not used by this
                implementation.
        """
        # The imported name is intentionally unused (hence the noqa).
        # NOTE(review): presumably the import itself registers the "ascend"
        # quantization method - confirm in vllm_ascend.quantization.
        from vllm_ascend.quantization.quant_config import \
            AscendQuantConfig  # noqa: F401

    @classmethod
    def get_device_capability(cls, device_id: int = 0):
        """Return ``None``; no device capability is reported for NPU."""
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name ``torch.npu`` reports for ``device_id``."""
        return torch.npu.get_device_name(device_id)

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return ``True`` unconditionally; ``enforce_eager`` is ignored."""
        return True

    @classmethod
    def inference_mode(cls):
        """Return a ``torch.inference_mode()`` context manager."""
        return torch.inference_mode()

    @classmethod
    def set_device(cls, device: torch.device):
        """Make ``device`` the current NPU device."""
        torch.npu.set_device(device)

    @classmethod
    def empty_cache(cls):
        """Call ``torch.npu.empty_cache()``."""
        torch.npu.empty_cache()

    @classmethod
    def synchronize(cls):
        """Call ``torch.npu.synchronize()``."""
        torch.npu.synchronize()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        """Return the result of ``torch.npu.mem_get_info()``."""
        return torch.npu.mem_get_info()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place to settings usable on NPU.

        * Forces the compilation level to ``NO_COMPILATION`` (with a
          warning) when another level was requested.
        * Resolves ``worker_cls == "auto"`` to the matching vllm_ascend
          worker class path.
        * When a cache config is present, defaults ``block_size`` to 128
          and, under the V1 engine, disables prefix caching (with a
          warning).

        Args:
            vllm_config: The configuration object to inspect and mutate.
        """
        compilation_config = vllm_config.compilation_config
        if compilation_config.level != CompilationLevel.NO_COMPILATION:
            logger.warning(
                "Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
                compilation_config.level)
            compilation_config.level = CompilationLevel.NO_COMPILATION

        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            if envs.VLLM_USE_V1:
                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
            elif vllm_config.scheduler_config.is_multi_step:
                parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
            else:
                parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"

        cache_config = vllm_config.cache_config
        # Both cache adjustments sit under the same guard: the prefix-caching
        # check used to dereference cache_config even when it was None.
        if cache_config:
            if cache_config.block_size is None:
                cache_config.block_size = 128

            if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching:
                logger.warning(
                    "Prefix caching is not supported for V1 now, disable prefix caching"
                )
                cache_config.enable_prefix_caching = False

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla):
        """Return the dotted path of the attention backend class to use.

        Only ``use_v1`` and ``use_mla`` influence the result; the remaining
        parameters are accepted but not read.
        """
        # NOTE(review): use_v1 is tested first, so use_v1 together with
        # use_mla selects the non-MLA V1 backend - confirm this is intended.
        if use_v1:
            return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
        if use_mla:
            return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
        return "vllm_ascend.attention.attention.AscendAttentionBackend"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Reset the peak-memory statistics, then return the peak allocation.

        Side effect: ``torch.npu.reset_peak_memory_stats(device)`` is called
        before the value is read, so previously recorded peaks are cleared.

        Args:
            device: Device to query; passed through to ``torch.npu``.

        Returns:
            The value of ``torch.npu.max_memory_allocated(device)``.
        """
        torch.npu.reset_peak_memory_stats(device)
        return torch.npu.max_memory_allocated(device)

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the NPU device communicator class."""
        return "vllm_ascend.communicator.NPUCommunicator"

    @classmethod
    def is_pin_memory_available(cls):
        """Return ``True`` unconditionally."""
        return True
|