[V1][Core] Add support for V1 Engine (#295)

### What this PR does / why we need it?
Add support for V1 Engine.

Please note that this is just the initial version; some parts may still need to be
fixed or optimized in the future, so feel free to leave comments for us.

### Does this PR introduce _any_ user-facing change?

To use the V1 Engine on an NPU device, you need to set the environment variables
shown below:

```bash
export VLLM_USE_V1=1
export VLLM_WORKER_MULTIPROC_METHOD=spawn
```

If you are using vLLM for offline inference, you must wrap the entry point in a
`__main__` guard, for example:

```python
import vllm

if __name__ == '__main__':
    llm = vllm.LLM(...)
```

Find more details
[here](https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing).
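
For reference, here is a minimal offline-inference sketch for the V1 engine on an NPU; the model name, `max_model_len`, and sampling settings are illustrative and not part of this PR:

```python
# Minimal offline-inference sketch; model and sampling settings are illustrative.
import os

from vllm import LLM, SamplingParams

# Enable the V1 engine and use the spawn multiprocessing method on NPU.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

if __name__ == "__main__":
    llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", max_model_len=26240)
    outputs = llm.generate(["The future of AI is"],
                           SamplingParams(temperature=0, max_tokens=7))
    for output in outputs:
        print(output.outputs[0].text)
```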

### How was this patch tested?
I have tested online serving with `Qwen2.5-7B-Instruct` using this command:

```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --max_model_len 26240
```

Query the model with an input prompt:

```bash
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "Qwen/Qwen2.5-7B-Instruct",
        "prompt": "The future of AI is",
        "max_tokens": 7,
        "temperature": 0
    }'
```
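
Equivalently, the same request can be issued with the OpenAI-compatible Python client (shown only for convenience, not part of this PR):

```python
from openai import OpenAI

# Point the client at the local vLLM server; the API key is unused but required.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    prompt="The future of AI is",
    max_tokens=7,
    temperature=0,
)
print(completion.choices[0].text)
```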

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: didongli182 <didongli@huawei.com>
Commit c06af8b2e0 (parent 663dca7578): 13 changed files with 1385 additions and 38 deletions. The excerpt below shows the changes to the NPU platform plugin.
```diff
@@ -19,13 +19,10 @@ import os
 from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
-try:
-    import torch_npu  # noqa: F401
-except ImportError:
-    print("Failed to import torch_npu.")
-
-from vllm.config import VllmConfig
+import torch_npu  # noqa: F401
+import vllm.envs as envs
+from vllm.config import CompilationLevel, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import Platform, PlatformEnum
 
 if TYPE_CHECKING:
@@ -35,18 +32,7 @@ else:
     os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
 
 
-def _device_id_to_physical_device_id(device_id: int) -> int:
-    if "ASCEND_RT_VISIBLE_DEVICES" in os.environ:
-        device_ids = os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")
-        if device_ids == [""]:
-            raise RuntimeError("ASCEND_RT_VISIBLE_DEVICES is set to empty"
-                               "string, which means Ascend NPU support is"
-                               "disabled.")
-        physical_device_id = device_ids[device_id]
-        return int(physical_device_id)
-    else:
-        return device_id
 logger = init_logger(__name__)
 
 
 class NPUPlatform(Platform):
@@ -74,8 +60,7 @@ class NPUPlatform(Platform):
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
-        physical_device_id = _device_id_to_physical_device_id(device_id)
-        return torch.npu.get_device_name(physical_device_id)
+        return torch.npu.get_device_name(device_id)
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
@@ -103,23 +88,41 @@ class NPUPlatform(Platform):
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        compilation_config = vllm_config.compilation_config
+        if compilation_config.level != CompilationLevel.NO_COMPILATION:
+            logger.warning(
+                "Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
+                compilation_config.level)
+            compilation_config.level = CompilationLevel.NO_COMPILATION
+
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
-            if vllm_config.scheduler_config.is_multi_step:
-                parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
+            if envs.VLLM_USE_V1:
+                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
             else:
-                parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
+                if vllm_config.scheduler_config.is_multi_step:
+                    parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
+                else:
+                    parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
 
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128
 
+        if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching:
+            logger.warning(
+                "Prefix caching is not supported for V1 now, disable prefix caching"
+            )
+            cache_config.enable_prefix_caching = False
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla):
+        if use_v1:
+            return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
         if use_mla:
-            return "vllm_ascend.attention.AscendMLAAttentionBackend"
-        return "vllm_ascend.attention.AscendAttentionBackend"
+            return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
+        return "vllm_ascend.attention.attention.AscendAttentionBackend"
 
     @classmethod
     def get_current_memory_usage(cls,
@@ -131,3 +134,7 @@ class NPUPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm_ascend.communicator.NPUCommunicator"
+
+    @classmethod
+    def is_pin_memory_available(cls):
+        return True
```