From 538a69c1459cc8fde032b8db211ea215c78063b9 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Tue, 22 Apr 2025 14:13:00 +0800 Subject: [PATCH] [Patch] format patch module to make it more clear (#601) Format patch module to make it more clear. Add the patch doc description, the new patch must follow this guide. Signed-off-by: wangxiyuan --- vllm_ascend/patch/__init__.py | 134 +++++++++++++++++- .../patch/platform/patch_0_8_4/__init__.py | 12 -- .../patch/platform/patch_common/__init__.py | 4 +- .../patch_distributed.py | 34 ----- .../patch/platform/patch_main/__init__.py | 1 - .../platform/patch_main/patch_distributed.py | 32 ----- .../patch/worker/patch_common/__init__.py | 60 -------- 7 files changed, 136 insertions(+), 141 deletions(-) rename vllm_ascend/patch/platform/{patch_0_8_4 => patch_common}/patch_distributed.py (74%) delete mode 100644 vllm_ascend/patch/platform/patch_main/patch_distributed.py diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 2ed088b..6de4cd1 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -13,4 +13,136 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file + +# ---------------------------------------------------------------------------------- +# This module manage the patch for vllm. There are two folders in this module: +# - platform: contains the patches applied before worker starts. It's called by +# `vllm_ascend.utils.adapt_patch(is_global_patch=True)` in +# `vllm_ascend.platform.NPUPlatform.pre_register_and_update()` function. +# - worker: contains the patches applied when worker starts. It's called by +# `vllm_ascend.utils.adapt_patch(is_global_patch=False)` in +# each worker's `__init__` function. 
+#
+# Then in each kind of patch, there are three folders:
+# - patch_0_8_4: contains the patches applied when vllm version is 0.8.4.
+# - patch_main: contains the patches applied when vllm version is main branch.
+# - patch_common: contains the patches applied in both 0.8.4 and main branch.
+#
+# In the future, with the vllm version upgrade, new patch folders such as
+# patch_0_8_5, patch_0_8_6, etc. will be added to manage the patches for different
+# vllm versions. And the patch_common will contain the patches applied in all the
+# vllm versions.
+# Once a vllm version is too old for vllm-ascend to support, the related
+# patch folder will be removed as well.
+#
+# Once a new patch is added in vllm-ascend, please add the patch description into this file as well.
+# ----------------------------------------------------------------------------------
+
+# What's Patched and how it works:
+# --------------------------------
+# * Platform Patch:
+# =================
+# ** File: platform/patch_0_8_4/patch_config.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.config.ModelConfig.__init__()`
+# Why:
+# It is hard coded for sleep mode to support cuda platform only
+# How:
+# Using a new method to check if sleep mode is available
+# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+# https://github.com/vllm-project/vllm/pull/16562
+# Future Plan:
+# This patch is only used for 084 and can't be reverted. Just keep it as it is.
+#
+# ** File: platform/patch_common/patch_distributed.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. 
`vllm.distributed.parallel_state.destroy_model_parallel()`
+# Why:
+# vllm does not support an outside platform maintaining its own `CoordinatorGroup`; vllm-ascend maintains EP and ETP
+# inside of the repo, and needs a common interface to destroy them. This patch adds the interface to destroy
+# platform owned `CoordinatorGroup` to make sure all the CoordinateGroup can be properly destroyed
+# How:
+# Call platform method `destroy_platform_model_parallel` to destroy all the `CoordinateGroup`
+# Related PR (if no, explain why): no related PR, we want to add this ability into vllm
+# Future Plan:
+# Remove this patch when vllm merges it
+# 2. `vllm.distributed.stateless_init_torch_distributed_process_group()`
+# Why:
+# The stateless process group can not be initialized except from the gloo and nccl backends; vllm-ascend
+# needs to initialize its own stateless process group for communication, so we add the platform related
+# call to the `stateless_init_torch_distributed_process_group`, to enable other platforms which may support
+# their own stateless process group initialize method
+# How:
+# Call platform method `platform_has_backend_register` to check if there is a stateless process group initialize
+# method and call platform method `platform_register_backend` to initialize them
+# Related PR (if no, explain why): no related PR, we want to add this ability into vllm
+# Future Plan:
+# Remove this patch when vllm merges it
+# 3. `ParallelConfig.get_next_dp_init_port`
+# Why:
+# We want to get the dp port from an env variable, so the multi-node inference can be properly initialized and run. 
+# How:
+# Get the dp port from the env variable to enable multi-node dp inference
+# Related PR (if no, explain why): no related PR, we want to add this ability into vllm
+# Future Plan:
+# It's a workaround in vllm-ascend to enable multi-node dp inference, may be removed if vllm has a better plan
+# on multi-node dp inference implementation
+#
+# * Worker Patch:
+# ===============
+# ** File: worker/patch_common/patch_metrics.py **
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.spec_decode.metrics.AsyncMetricsCollector.init_tensors` and
+# `vllm.spec_decode.metrics.AsyncMetricsCollector._copy_rejsample_metrics_async`
+# Why:
+# There is cuda hard code (torch.cuda.Stream) in `AsyncMetricsCollector.init_tensors` and
+# `AsyncMetricsCollector._copy_rejsample_metrics_async`
+# How:
+# Replace it with the corresponding npu method
+# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+# https://github.com/vllm-project/vllm/pull/14411
+# Future Plan:
+# Revert it when the related pr is merged in vllm.
+#
+# 2. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics`
+# Why:
+# There is cuda hard code (current_platform.is_cuda_alike()) in
+# `AsyncMetricsCollector.maybe_collect_rejsample_metrics`
+# How:
+# Change to use `current_platform.Event` to determine whether to return None
+# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+# https://github.com/vllm-project/vllm/pull/14411
+# Future Plan:
+# Revert it when the related pr is merged in vllm.
+#
+# ** File: worker/patch_common/patch_multi_step_worker.py **
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
+# Why:
+# There is cuda hard code (current_platform.is_cuda_alike()) in
+# `MultiStepWorker.sampler_output`, and we need to use the patched `TP1DraftModelRunner` in it. 
+# How:
+# Make speculative decoding extensible to different backends.
+# - support attention metadata register to the set supported spec decode
+# - offer an api in platform to determine whether spec decode is supported,
+# and deprecate is_cuda_alike in it.
+# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+# - https://github.com/vllm-project/vllm/pull/15195
+# - https://github.com/vllm-project/vllm-ascend/pull/395
+# Future Plan:
+# Revert it when the related pr is merged in vllm and vllm-ascend.
+#
+# ** File: worker/patch_common/patch_spec_decode_worker.py **
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
+# Why:
+# We need to use the patched `TP1DraftModelRunner` in `SpecDecodeWorker.create_worker`.
+# The main reason to overwrite `TP1DraftModelRunner` is the hard code of
+# `FlashAttentionMetadata`
+# How:
+# ditto
+# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+# - https://github.com/vllm-project/vllm/pull/15195
+# - https://github.com/vllm-project/vllm-ascend/pull/395
+# Future Plan:
+# Revert it when the related pr is merged in vllm and vllm-ascend.
diff --git a/vllm_ascend/patch/platform/patch_0_8_4/__init__.py b/vllm_ascend/patch/platform/patch_0_8_4/__init__.py
index 9e6adeb..a058380 100644
--- a/vllm_ascend/patch/platform/patch_0_8_4/__init__.py
+++ b/vllm_ascend/patch/platform/patch_0_8_4/__init__.py
@@ -14,17 +14,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-# What's Patched and how it works:
-# ** File: platform/patch_0_8_4/patch_config.py**
-# 1. `vllm.config.ModelConfig.__init__()`
-# Why:
-# It is hard coded for sleep mode to support cuda platform only
-# How:
-# Using a new method to check if sleep mode is available
-# Related PR (if no, explain why): 1. refused by vllm. 2. 
vllm doesn't support 3. prepare to submit.... -# https://github.com/vllm-project/vllm/pull/16562 -# Future Plan: -# This patch is only used for 084 and can't be revert. just keep as it is. import vllm_ascend.patch.platform.patch_0_8_4.patch_config # noqa -import vllm_ascend.patch.platform.patch_0_8_4.patch_distributed # noqa diff --git a/vllm_ascend/patch/platform/patch_common/__init__.py b/vllm_ascend/patch/platform/patch_common/__init__.py index 2ed088b..f88f2a9 100644 --- a/vllm_ascend/patch/platform/patch_common/__init__.py +++ b/vllm_ascend/patch/platform/patch_common/__init__.py @@ -13,4 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# + +import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa diff --git a/vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py similarity index 74% rename from vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py rename to vllm_ascend/patch/platform/patch_common/patch_distributed.py index 3efbd45..1b356a9 100644 --- a/vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py @@ -27,40 +27,6 @@ from torch.distributed.distributed_c10d import (Backend, PrefixStore, from torch.distributed.rendezvous import rendezvous from vllm.config import ParallelConfig -# What's Patched and how it works: -# ** File: platform/patch_0_8_4/patch_distributed.py** -# 1. 
`vllm.distributed.parallel_state.destroy_model_parallel()` -# Why: -# vllm dose not support outside platform maintain its own `CoordinatorGroup`, vllm-ascend maintain EP and ETP -# inside of the repo, and needs a common interface to destroy them, this patch add the interface of destroy -# platform owned `CoordinatorGroup` to make sure all the CoordinateGroup can be properly destroyed -# How: -# Call platform method `destroy_platform_model_parallel` to destroy all the `CoordinateGroup` -# Related PR (if no, explain why): no related PR, we want add this ability into vllm -# Future Plan: -# Remove those patch when vllm merged them -# 2. `vllm.distributed.stateless_init_torch_distributed_process_group()` -# Why: -# The stateless process group can not be initialized except from gloo and nccl backend, vllm-ascend -# needs to initialize its own stateless process group for communication, so we add the platform related -# call to the `stateless_init_torch_distributed_process_group`, to enable other platform which may support -# stateless process group initialize method -# How: -# Call platform method `platform_has_backend_register` to judge if there is a stateless process group initialize -# method and call platform method `platform_register_backend` to initialize them -# Related PR (if no, explain why): no related PR, we want add this ability into vllm -# Future Plan: -# Remove those patch when vllm merged them -# 3. `ParallelConfig.get_next_dp_init_port` -# Why: -# We want to get dp port from env variable, so the multi-node inference can be properly initialized and run. 
-# How: -# Get the dp port from env variable enable multi-mode dp inference -# Related PR (if no, explain why): no related PR, we want add this ability into vllm -# Future Plan: -# Its a workaround in vllm-ascend to enable multi-node dp inference, maybe removed if vllm have better plan -# on multi-node dp inference implementation - def ascend_destroy_model_parallel(): """Set the groups to none and destroy them.""" diff --git a/vllm_ascend/patch/platform/patch_main/__init__.py b/vllm_ascend/patch/platform/patch_main/__init__.py index d430dbe..116c73c 100644 --- a/vllm_ascend/patch/platform/patch_main/__init__.py +++ b/vllm_ascend/patch/platform/patch_main/__init__.py @@ -14,4 +14,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import vllm_ascend.patch.platform.patch_main.patch_distributed # noqa F401 \ No newline at end of file diff --git a/vllm_ascend/patch/platform/patch_main/patch_distributed.py b/vllm_ascend/patch/platform/patch_main/patch_distributed.py deleted file mode 100644 index bdac50b..0000000 --- a/vllm_ascend/patch/platform/patch_main/patch_distributed.py +++ /dev/null @@ -1,32 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from vllm/model_executor/models/qwen2_vl.py -# This file is a part of the vllm-ascend project. 
- -import vllm -import vllm.distributed -from vllm.config import ParallelConfig - -from vllm_ascend.patch.platform.patch_0_8_4.patch_distributed import ( - ascend_destroy_model_parallel, - ascend_stateless_init_torch_distributed_process_group, - parallel_config_get_dp_port) - -# All details of those patch please refer to vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py -vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel -vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group -ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index 5e5e44c..590074f 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -15,66 +15,6 @@ # limitations under the License. # -# What's Patched and how it works: -# ** File: worker/patch_common/patch_metrics.py ** -# 1. `vllm.spec_decode.metrics.AsyncMetricsCollector.init_tensors` and -# `vllm.spec_decode.metrics.AsyncMetricsCollector._copy_rejsample_metrics_async` -# Why: -# There are cuda hard code (torch.cuda.Stream) in `AsyncMetricsCollector.init_tensors` and -# `AsyncMetricsCollector._copy_rejsample_metrics_async` -# How: -# Replace it with the corresponding npu method -# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit.... -# https://github.com/vllm-project/vllm/pull/14411 -# Future Plan: -# Revert it when the related pr is merged in vllm. -# -# 2. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics` -# Why: -# There are cuda hard code (current_platform.is_cuda_alike()) in -# `AsyncMetricsCollector.maybe_collect_rejsample_metrics` -# How: -# Change to use `current_platform.Event` to determine whether to return None -# Related PR (if no, explain why): 1. 
refused by vllm. 2. vllm doesn't support 3. prepare to submit.... -# https://github.com/vllm-project/vllm/pull/14411 -# Future Plan: -# Revert it when the related pr is merged in vllm. -# -# ** File: worker/patch_common/patch_multi_step_worker.py ** -# 1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output` -# Why: -# There are cuda hard code (current_platform.is_cuda_alike()) in -# `MultiStepWorker.sampler_output`, and we need to use the patched `TP1DraftModelRunner` in it. -# How: -# Make speculative decoding extensible to different backends. -# - support attention metadata register to the set supported spec decode -# - offer a api in platform to determine whether spec decode is supported, -# and deprecate is_cuda_alike in it. -# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit.... -# - https://github.com/vllm-project/vllm/pull/15195 -# - https://github.com/vllm-project/vllm-ascend/pull/395 -# Future Plan: -# Revert it when the related pr is merged in vllm and vllm-ascend. -# -# ** File: worker/patch_common/patch_multi_step_worker.py ** -# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker` -# Why: -# We need to use the patched `TP1DraftModelRunner` in `SpecDecodeWorker.create_worker`. -# The mainly reason to overwrite `TP1DraftModelRunner`is the hard code of -# `FlashAttentionMetadata` -# How: -# ditto -# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit.... -# - https://github.com/vllm-project/vllm/pull/15195 -# - https://github.com/vllm-project/vllm-ascend/pull/395 -# Future Plan: -# Revert it when the related pr is merged in vllm and vllm-ascend. - -# current_platform.is_cuda_alike() -# 0.8.4 patch doc: -# platform-0.8.4 + platform-common + worker-0.8.4 + worker-common -# ... 
- import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa