### What this PR does / why we need it?
Add support for V1 Engine.
Please note that this is only the initial version; some parts may still
need to be fixed or optimized in the future. Feel free to leave comments
for us.
### Does this PR introduce _any_ user-facing change?
To use the V1 Engine on an NPU device, you need to set the environment
variables shown below:
```bash
export VLLM_USE_V1=1
export VLLM_WORKER_MULTIPROC_METHOD=spawn
```
If you are using vLLM for offline inference, you must add a `__main__`
guard like:
```python
if __name__ == '__main__':
    llm = vllm.LLM(...)
```
Find more details
[here](https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing).
### How was this patch tested?
I have tested the online serving with `Qwen2.5-7B-Instruct` using this
command:
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --max_model_len 26240
```
Query the model with input prompts:
```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2.5-7B-Instruct",
"prompt": "The future of AI is",
"max_tokens": 7,
"temperature": 0
}'
```
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: didongli182 <didongli@huawei.com>
247 lines
10 KiB
Python
247 lines
10 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import gc
|
|
from typing import Dict, List, Optional
|
|
|
|
import torch
|
|
import torch.distributed
|
|
import torch.nn as nn
|
|
import torch_npu
|
|
from vllm import envs
|
|
from vllm.config import ParallelConfig, VllmConfig
|
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
|
init_distributed_environment,
|
|
set_custom_all_reduce)
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor import set_random_seed
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
|
from vllm.v1.core.scheduler import SchedulerOutput
|
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
|
KVCacheSpec)
|
|
from vllm.v1.outputs import ModelRunnerOutput
|
|
from vllm.v1.utils import bind_kv_cache
|
|
from vllm.v1.worker.worker_base import WorkerBase
|
|
|
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class NPUWorker(WorkerBase):
    """Worker that drives an ``NPUModelRunner`` on one NPU device.

    The worker owns the device setup, the distributed environment
    initialization, memory profiling for the KV cache, and the optional
    ``torch_npu`` profiler. Model execution itself is delegated to
    ``self.model_runner``, which is created in :meth:`init_device`.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False):
        """Store the configuration and optionally set up the profiler.

        Args:
            vllm_config: Aggregated vLLM configuration; its sub-configs are
                copied onto the worker as individual attributes.
            local_rank: Index used to select the ``npu:<local_rank>`` device
                in :meth:`init_device`.
            rank: Rank of this worker; forwarded to the base class and to
                the distributed environment initialization.
            distributed_init_method: Init method forwarded to the
                distributed environment initialization.
            is_driver_worker: Forwarded unchanged to the base class.
        """
        # Register ops when worker init (the import is done for its side
        # effect only).
        from vllm_ascend import ops  # noqa: F401

        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        # "auto" means the KV cache uses the model dtype; any other value
        # is looked up in the string -> torch dtype table.
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)

            experimental_config = torch_npu.profiler._ExperimentalConfig(
                export_type=torch_npu.profiler.ExportType.Text,
                profiler_level=torch_npu.profiler.ProfilerLevel.Level0,
                msprof_tx=False,
                aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
                l2_cache=False,
                op_attr=False,
                data_simplification=False,
                record_op_args=False,
                gc_detect_threshold=None,
            )

            self.profiler = torch_npu.profiler.profile(
                activities=[
                    torch_npu.profiler.ProfilerActivity.CPU,
                    torch_npu.profiler.ProfilerActivity.NPU,
                ],
                with_stack=True,
                profile_memory=True,
                with_modules=True,
                experimental_config=experimental_config,
                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir))
        else:
            self.profiler = None

    def init_device(self):
        """Select the NPU, set up distributed state and build the runner.

        Records the free device memory (after emptying the cache) in
        ``self.init_npu_memory`` so that :meth:`determine_available_memory`
        can later compute how much memory profiling consumed.

        Raises:
            RuntimeError: If the configured device type is not ``"npu"``.
        """
        if self.device_config.device.type == "npu":
            self.device = torch.device(f"npu:{self.local_rank}")
            current_platform.set_device(self.device)

            current_platform.empty_cache()
            # First element of mem_get_info() is the free memory baseline.
            self.init_npu_memory = current_platform.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = NPUModelRunner(self.vllm_config, self.device)

    def determine_available_memory(self) -> int:
        """Profile a dummy forward pass and size the KV cache budget.

        Returns:
            ``total_memory * gpu_memory_utilization - peak_memory`` clamped
            at zero and truncated to ``int``, where ``peak_memory`` is the
            drop in free device memory since :meth:`init_device`.

        Raises:
            NotImplementedError: If a layer's KV cache spec is not a
                ``FullAttentionSpec``.
            AssertionError: If free memory did not decrease during
                profiling.
        """
        kv_caches: Dict[str, torch.Tensor] = {}
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        for layer_name, layer_spec in kv_cache_spec.items():
            if isinstance(layer_spec, FullAttentionSpec):
                dtype = layer_spec.dtype

                # Use an empty tensor instead of ``None`` to force Dynamo to
                # pass it by reference, rather than specializing on the
                # value ``None``.
                k_cache_placeholder = torch.tensor([],
                                                   dtype=dtype,
                                                   device=self.device)
                v_cache_placeholder = torch.tensor([],
                                                   dtype=dtype,
                                                   device=self.device)

                kv_caches[layer_name] = (k_cache_placeholder,
                                         v_cache_placeholder)
            else:
                raise NotImplementedError

        runner_kv_caches: List[torch.Tensor] = []
        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            runner_kv_caches)

        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        current_platform.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        free_npu_memory, total_npu_memory = current_platform.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # device did not change their memory usage during the profiling.
        peak_memory = self.init_npu_memory - free_npu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_npu_memory}, current free memory"
            f" {free_npu_memory}. This happens when the NPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        gc.collect()
        # TODO: no need to implement this function once empty_cache in
        # Worker.determine_num_available_blocks() is unified.
        current_platform.empty_cache()
        usable_memory_size = (
            total_npu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)
        # A negative budget means profiling already exceeded the allowed
        # share; report zero bytes in that case.
        npu_kv_cache_bytes = max(usable_memory_size, 0)
        logger.info("Available memory: %s, total memory: %s",
                    usable_memory_size, total_npu_memory)
        return int(npu_kv_cache_bytes)

    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run the model runner for one scheduler step.

        Returns:
            The runner's output on rank 0, ``None`` on every other rank.
        """
        output = self.model_runner.execute_model(scheduler_output)
        return output if self.rank == 0 else None

    def load_model(self) -> None:
        """Delegate model loading to the model runner."""
        self.model_runner.load_model()

    def compile_or_warm_up_model(self) -> None:
        """Warn if eager mode is not enforced, then reset the random seed."""
        if not self.model_config.enforce_eager:
            logger.warning("Graph capture is not supported on NPU.")
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Return the KV cache spec reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate NPU KV cache with the specified kv_cache_config."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
        """Allocate NPU KV cache using this rank's entry of kv_cache_configs.

        Args:
            kv_cache_configs: Per-rank configs; the entry at index
                ``self.rank`` is the one applied to this worker.
        """
        kv_cache_config = kv_cache_configs[self.rank]
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def profile(self, is_start: bool = True):
        """Start or stop the profiler created in ``__init__``.

        Args:
            is_start: ``True`` starts the profiler, ``False`` stops it.

        Raises:
            RuntimeError: If no profiler was configured, i.e.
                ``VLLM_TORCH_PROFILER_DIR`` was unset at construction time.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
|
|
|
|
|
|
def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Set up the distributed state for one worker using the HCCL backend.

    Args:
        parallel_config: Supplies the custom all-reduce flag, the world
            size, and the tensor/pipeline parallel sizes.
        rank: Rank of the calling worker.
        distributed_init_method: Init method forwarded to
            ``init_distributed_environment``.
        local_rank: Local rank forwarded to
            ``init_distributed_environment``.
    """
    # Custom all-reduce is enabled unless the config explicitly disables it.
    use_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    backend = "hccl"
    init_distributed_environment(
        parallel_config.world_size,
        rank,
        distributed_init_method,
        local_rank,
        backend,
    )

    # Parallel sizes are read only after the base environment is up,
    # matching the original call order.
    tp_size = parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size)
|