### What this PR does / why we need it?
This PR resolves [issue 1147](https://github.com/vllm-project/vllm-ascend/issues/1147):
1. Move the fused MoE code into a single file, `fused_moe.py`.
2. Consolidate the scattered branch conditions into one function, `get_fused_moe_state` (see the sketch below).

### Does this PR introduce _any_ user-facing change?
1. This PR removes the env `VLLM_ENABLE_MC2`. I think this env is unnecessary: we can make the judgment from the current scenario without it, so it only adds complexity.
2. This PR removes the env `USING_LCCL_COM`, which has already expired.
3. `additional_config.expert_tensor_parallel_size` has also expired; use the `enable_expert_parallel` parameter instead, consistent with vLLM.
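For context, here is a minimal sketch of the shape the consolidated selection helper can take. The `FusedMoEState` enum, the EP-size threshold, and the exact conditions are illustrative assumptions for review purposes, not a copy of the merged code:

```python
from enum import Enum


class FusedMoEState(Enum):
    AllGather = 0
    All2All = 1
    MC2 = 2


def get_fused_moe_state(ep_size: int, with_prefill: bool) -> FusedMoEState:
    """Pick the fused-MoE communication scheme from the runtime scenario,
    replacing the old VLLM_ENABLE_MC2 env switch."""
    if ep_size == 1:
        # No expert parallelism: a plain all-gather is enough.
        return FusedMoEState.AllGather
    if ep_size < 16 or with_prefill:
        # Small EP groups (threshold illustrative) or prefill phases use
        # all-to-all dispatch/combine.
        return FusedMoEState.All2All
    # Large EP groups in decode take the MC2 fused-kernel path.
    return FusedMoEState.MC2
```

On the user side, expert parallelism is now enabled through the standard vLLM switch rather than a vllm-ascend-specific knob, e.g. `LLM(model=..., enable_expert_parallel=True)` or `--enable-expert-parallel` on the CLI.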
### How was this patch tested?

Signed-off-by: zzzzwwjj <1183291235@qq.com>
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
#

from typing import Optional

import torch
import torch.nn as nn
import torch_npu
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase

from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import try_register_lib
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


class NPUWorker(WorkerBase):

    def __init__(
            self,
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            is_driver_worker: bool = False,
            # Additional parameters for compatibility with vllm
            **kwargs):
        """Initialize the worker for Ascend."""
        # Register patches for vLLM.
        from vllm_ascend.utils import adapt_patch
        adapt_patch()
        # Register ops when the worker initializes.
        from vllm_ascend import ops
        ops.register_dummy_fusion_op()
        _register_atb_extensions()
        # Init ascend config.
        init_ascend_config(vllm_config)

        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        # Try to import mindie_turbo to accelerate vLLM inference.
        try_register_lib(
            "mindie_turbo",
            "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
        )
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        self.profiler = self._init_profiler()

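    # Sleep mode is backed by CaMemAllocator: level 1 offloads allocations
    # tagged "weights" to host memory and discards the rest (e.g. the KV
    # cache), while level 2 discards everything; wake_up() restores the
    # sleeping allocations on the NPU afterwards.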
    def sleep(self, level: int = 1) -> None:
        NPUPlatform.set_device(self.device)
        free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
        allocator = CaMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = NPUPlatform.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        allocator = CaMemAllocator.get_instance()
        allocator.wake_up(tags=tags)

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        if self.device_config.device.type == "npu":
            self.device = torch.device(f"npu:{self.local_rank}")
            NPUPlatform.set_device(self.device)
            NPUPlatform.empty_cache()
            self.init_npu_memory = NPUPlatform.mem_get_info()[0]
        else:
            info = f"Unsupported device type: {self.device_config.device}"
            logger.error(info)
            raise RuntimeError(info)
        # Initialize the distributed environment.
        self._init_worker_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = NPUModelRunner(self.vllm_config, self.device)

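    # The KV-cache budget computed below is:
    #     available = total * gpu_memory_utilization - peak
    # where `peak` also counts non-torch allocations. As an illustrative
    # example: a 64 GiB device with gpu_memory_utilization=0.9 and a
    # profiled peak of 40 GiB leaves 64 * 0.9 - 40 = 17.6 GiB for KV cache.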
    def determine_available_memory(self) -> int:
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        NPUPlatform.clear_npu_memory()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        _, total_npu_memory = NPUPlatform.mem_get_info()
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        free_npu_memory, _ = NPUPlatform.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # NPU did not change their memory usage during the profiling.
        assert self.init_npu_memory > free_npu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_npu_memory}, current free memory"
            f" {free_npu_memory}. This happens when the NPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        # Get the peak memory allocation recorded by torch.
        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
        # TODO: this call can be dropped once the empty_cache in
        # Worker.determine_num_available_blocks() is unified.
        NPUPlatform.empty_cache()
        torch_allocated_bytes = torch_npu.npu.memory_stats(
        )["allocated_bytes.all.current"]
        total_allocated_bytes = torch_npu.npu.mem_get_info(
        )[1] - torch_npu.npu.mem_get_info()[0]
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
        available_kv_cache_memory = (
            total_npu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)
        return int(available_kv_cache_memory)

    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        output = self.model_runner.execute_model(scheduler_output)
        return output if self.is_driver_worker else None

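    # When sleep mode is enabled, model weights are allocated inside a
    # CaMemAllocator pool tagged "weights", so that sleep()/wake_up() can
    # offload and restore exactly this region later.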
    def load_model(self) -> None:
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CaMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            from contextlib import nullcontext
            context = nullcontext()  # type: ignore
        with context:
            self.model_runner.load_model()

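    # Warm up the configured compile_sizes; when graphs are captured, sizes
    # already in cudagraph_capture_sizes are skipped here because
    # capture_model() will exercise them anyway.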
    def compile_or_warm_up_model(self) -> None:
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compiling and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate NPU KV cache with the specified kv_cache_config."""
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CaMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
        else:
            from contextlib import nullcontext
            context = nullcontext()  # type: ignore
        with context:
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def profile(self, is_start: bool = True):
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

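    # With data parallelism, every DP rank must launch a forward pass even
    # when it has no real requests, so collective ops (e.g. in MoE layers)
    # stay in sync across ranks; this runs such a dummy batch, padded to the
    # torchair graph batch size when graph mode is enabled without prefill.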
    def execute_dummy_batch(self) -> None:
        runner = self.model_runner
        max_num_tokens = 1
        with_prefill = False
        if runner.dp_size > 1:
            max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
                max_num_tokens, with_prefill)
        if runner.torchair_graph_enabled and not with_prefill:
            max_num_tokens = runner.select_torchair_padded_batch_size(
                max_num_tokens)
        runner._dummy_run(max_num_tokens,
                          is_compile=False,
                          with_prefill=with_prefill)

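    # The distributed environment uses the HCCL backend on Ascend; besides
    # the standard TP/PP groups, an Ascend-specific group for expert
    # parallelism is set up via init_ascend_model_parallel().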
    def _init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        parallel_config = self.vllm_config.parallel_config
        set_custom_all_reduce(
            not self.parallel_config.disable_custom_all_reduce)
        init_distributed_environment(self.parallel_config.world_size,
                                     self.rank, self.distributed_init_method,
                                     self.local_rank, "hccl")
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)
        init_ascend_model_parallel(
            parallel_config.expert_parallel_size,
            parallel_config.expert_tensor_parallel_size,
            parallel_config.world_size_across_dp,
        )
        ensure_kv_transfer_initialized(self.vllm_config)

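    # To collect a trace, launch the server with the env var set, e.g.
    #   VLLM_TORCH_PROFILER_DIR=/tmp/vllm_profile vllm serve <model>
    # (command illustrative); the trace is written via the TensorBoard
    # trace handler below.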
    def _init_profiler(self):
        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)

            experimental_config = torch_npu.profiler._ExperimentalConfig(
                export_type=torch_npu.profiler.ExportType.Text,
                profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
                msprof_tx=False,
                aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
                l2_cache=False,
                op_attr=False,
                data_simplification=False,
                record_op_args=False,
                gc_detect_threshold=None,
            )

            return torch_npu.profiler.profile(
                activities=[
                    torch_npu.profiler.ProfilerActivity.CPU,
                    torch_npu.profiler.ProfilerActivity.NPU,
                ],
                with_stack=False,
                profile_memory=False,
                with_modules=False,
                experimental_config=experimental_config,
                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir))
        else:
            return None