From 8e0ebb470a3174bcced8ff13b33627c9673c93e3 Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Thu, 19 Mar 2026 14:27:27 +0800
Subject: [PATCH] [Misc] Drop Prefetch MLP Env (#7357)

### What this PR does / why we need it?
Remove the deprecated environment variables for MLP weight prefetching
(`VLLM_ASCEND_ENABLE_PREFETCH_MLP`, `VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE`,
`VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE`) together with the backward-compatibility
code that consumed them. Their removal was already announced in the
deprecation warning.

### Does this PR introduce _any_ user-facing change?
Yes. The deprecated environment variables are no longer recognized; enable
MLP weight prefetching via `weight_prefetch_config` in `--additional-config`
instead.

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

Signed-off-by: wangxiyuan
---
 .../features/suffix_speculative_decoding.md |  4 +-
 vllm_ascend/ascend_config.py                | 39 ++-----------------
 vllm_ascend/envs.py                         | 10 -----
 vllm_ascend/ops/weight_prefetch.py          | 22 +++--------
 4 files changed, 10 insertions(+), 65 deletions(-)

diff --git a/docs/source/tutorials/features/suffix_speculative_decoding.md b/docs/source/tutorials/features/suffix_speculative_decoding.md
index bad8ea61..18b8b9e2 100644
--- a/docs/source/tutorials/features/suffix_speculative_decoding.md
+++ b/docs/source/tutorials/features/suffix_speculative_decoding.md
@@ -80,8 +80,6 @@ export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
 export TASK_QUEUE_ENABLE=1
 # Enable the AIVector core to directly schedule ROCE communication.
 export HCCL_OP_EXPANSION_MODE="AIV"
-# Enable MLP prefetch for better performance.
-export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1
 # Enable FlashComm_v1 optimization when tensor parallel is enabled.
 export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
 
@@ -94,7 +92,7 @@ vllm serve /data/Qwen3-32B \
     --max-num-batched-tokens 40960 \
     --speculative-config '{"method": "suffix", "num_speculative_tokens": 3}' \
     --gpu-memory-utilization 0.9 \
-    --additional-config '{"pa_shape_list":[48,64,72,80]}' \
+    --additional-config '{"pa_shape_list":[48,64,72,80], "weight_prefetch_config":{"enabled":true}}' \
     --port 8011
 ```

diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index ebd6c5aa..413c2bd0 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import warnings
 from typing import TYPE_CHECKING
 
 from vllm.logger import logger
@@ -48,9 +47,11 @@ class AscendConfig:
         eplb_config = additional_config.get("eplb_config", {})
         self.eplb_config = EplbConfig(eplb_config)
 
+        weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
+        self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
+
         # Dump / PrecisionDebugger configuration
         self.dump_config_path = additional_config.get("dump_config_path", None)
-        self._construct_weight_prefetch_config(additional_config)
         self.layer_sharding = additional_config.get("layer_sharding", None)
         if self.layer_sharding:
             logger.info_once(
@@ -158,29 +159,6 @@ class AscendConfig:
             and get_ascend_device_type() != AscendDeviceType.A5
         )
 
-    def _construct_weight_prefetch_config(self, additional_config):
-        weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
-        self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
-        # Deprecated env var handling for backward compatibility
-        if os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0") == "1":
-            MAX_PREFETCH_WEIGHT_SIZE: int = 18 * 1024 * 1024
-            gate_up_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
-            down_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
-            self.weight_prefetch_config.set_mlp_pre_version_compatibale_config(
-                gate_up_prefetch_size, down_prefetch_size
-            )
-            logger.info_once(
-                f"MLP weight prefetch enabled from env variable VLLM_ASCEND_ENABLE_PREFETCH_MLP."
-                f"gate_up_prefetch_size={gate_up_prefetch_size}, "
-                f"down_prefetch_size={down_prefetch_size}."
-            )
-            warnings.warn(
-                "VLLM_ASCEND_ENABLE_PREFETCH_MLP is deprecated and will be removed in a v0.16.0 version. "
-                "Please use weight_prefetch_config in additional-config for now instead.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
-
     @staticmethod
     def _get_compile_ranges(compilation_config):
         from vllm_ascend.utils import vllm_version_is
@@ -380,28 +358,19 @@ class WeightPrefetchConfig:
     Configuration Object for weight_prefetch_config from additional_config
     """
 
-    mlp_pre_version_compatibale_config: dict = {}
-
     prefetch_ratio: dict = {
         "attn": {
             "qkv": 1.0,
             "o": 1.0,
         },
         "moe": {"gate_up": 0.8},
-        "mlp": {"gate_up": 1, "down": 1.0},
+        "mlp": {"gate_up": 1.0, "down": 1.0},
     }
 
     def __init__(self, weight_prefetch_config: dict):
         self.enabled = weight_prefetch_config.get("enabled", False)
         self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio)
 
-    def set_mlp_pre_version_compatibale_config(self, gate_up_prefetch_size: int, down_prefetch_size: int):
-        config = {
-            "gate_up": gate_up_prefetch_size,
-            "down": down_prefetch_size,
-        }
-        self.mlp_pre_version_compatibale_config = config
-
 
 class EplbConfig:
     """
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 9fafb2ed..b161220e 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -77,16 +77,6 @@ env_variables: dict[str, Callable[[], Any]] = {
     # For a detailed introduction to the parameters and the differences and applicable scenarios
     # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
     "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
-    # Whether to enable MLP weight prefetch, only used in small concurrency.
- "VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0"))), - # buffer size for gate up prefetch - "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": lambda: int( - os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024) - ), - # buffer size for down proj prefetch - "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": lambda: int( - os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024) - ), # Whether to enable msMonitor tool to monitor the performance of vllm-ascend. "MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", "0"))), # Whether to enable MLAPO optimization for DeepSeek W8A8 series models. diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index d51b711f..09a052cd 100644 --- a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -43,9 +43,6 @@ class WeightPrefetchMethod: MLP_GATE_UP: str = "gate_up" MLP_DOWN: str = "down" - # backward compatibility: delete in future versions - mlp_pre_version_compatibale_config: dict = {} - def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None: self.is_moe = is_moe_model(get_current_vllm_config()) self.mla_sfa_prefetch_enable = weight_prefetch_config.enabled @@ -70,7 +67,6 @@ class WeightPrefetchMethod: enable=weight_prefetch_config.enabled and not self.is_moe, prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("mlp", {}) or {"gate_up": 1.0, "down": 1.0}, ) - self.mlp_pre_version_compatibale_config = weight_prefetch_config.mlp_pre_version_compatibale_config def maybe_prefetch_attn_weight_preprocess( self, layer_cls_name: str, weight: torch.Tensor, start_flag: torch.Tensor @@ -114,7 +110,7 @@ class WeightPrefetchMethod: def maybe_prefetch_mlp_weight_preprocess( self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None ): - if not self.mlp.enable and not self.mlp_pre_version_compatibale_config: + if not self.mlp.enable: self.mlp.is_active_this_forward = False return @@ -146,12 +142,9 @@ class WeightPrefetchMethod: model_instance = _EXTRA_CTX.model_instance layer_idx = int(curr_layer_prefix.split(".")[2]) weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight # type: ignore - if self.mlp_pre_version_compatibale_config: - weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0) - else: - weight_size = ( - weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0) - ) + weight_size = ( + weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0) + ) if weight_size > MAX_PREFETCH_WEIGHT_SIZE: weight_size = MAX_PREFETCH_WEIGHT_SIZE torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size)) @@ -161,12 +154,7 @@ class WeightPrefetchMethod: layer_idx = _EXTRA_CTX.layer_idx model_instance = _EXTRA_CTX.model_instance weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight # type: ignore - if self.mlp_pre_version_compatibale_config: - weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0) - else: - weight_size = ( - weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0) - ) + weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0) if weight_size > MAX_PREFETCH_WEIGHT_SIZE: weight_size = MAX_PREFETCH_WEIGHT_SIZE torch.ops.vllm.prefetch_preprocess(weight=weight, 
start_flag=x_dependency, max_weight_size=int(weight_size))
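
### Migration example

With the env vars gone, MLP weight prefetching is configured entirely through `weight_prefetch_config` in `--additional-config`; as the code above shows, the MLP path only activates for dense (non-MoE) models. A minimal sketch of the replacement invocation, with the model path and port reused from the tutorial for illustration:

```bash
# Before (removed by this patch):
#   export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1
#   export VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE=$((18 * 1024 * 1024))
#   export VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE=$((18 * 1024 * 1024))
#
# After: the fixed 18 MiB buffers are gone; each MLP prefetch now moves
# prefetch_ratio * weight_bytes, clamped to MAX_PREFETCH_WEIGHT_SIZE.
vllm serve /data/Qwen3-32B \
    --additional-config '{"weight_prefetch_config": {"enabled": true}}' \
    --port 8011
```

To tune the ratios, also pass a `prefetch_ratio` mapping, e.g. `{"mlp": {"gate_up": 1.0, "down": 1.0}}`. Note that `WeightPrefetchConfig.__init__` falls back to the class-level defaults only when the key is absent, so a user-supplied mapping replaces the whole `attn`/`moe`/`mlp` dict rather than merging into it.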