[Misc] Drop Prefetch MLP Env (#7357)
### What this PR does / why we need it?
Remove the deprecated environment variables related to MLP prefetching (`VLLM_ASCEND_ENABLE_PREFETCH_MLP`, `VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE`, and `VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE`), along with the backward-compatibility code paths that consumed them. MLP weight prefetching is now configured only through `weight_prefetch_config` in `--additional-config`.
### Does this PR introduce _any_ user-facing change?
Yes. The deprecated environment variables listed above no longer have any effect; configure `weight_prefetch_config` via `--additional-config` instead.
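For migration, here is a minimal sketch of the replacement configuration via the offline entry point. It assumes `vllm.LLM` forwards `additional_config` to the platform plugin the same way `vllm serve --additional-config` does; the `enabled` key mirrors what `WeightPrefetchConfig.__init__` reads in this diff (note the docs hunk below writes `"enable"`), and the model path and ratio values are illustrative only.

```python
# Hypothetical migration from the removed env vars:
#   export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1        -> "enabled": True
#   export VLLM_ASCEND_MLP_*_PREFETCH_SIZE=<bytes>  -> "prefetch_ratio" (ratio-based now)
from vllm import LLM

llm = LLM(
    model="/data/Qwen3-32B",
    additional_config={
        "weight_prefetch_config": {
            "enabled": True,
            # Optional override; replaces the default ratio table wholesale.
            "prefetch_ratio": {"mlp": {"gate_up": 1.0, "down": 1.0}},
        }
    },
)
```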
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -80,8 +80,6 @@ export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
 export TASK_QUEUE_ENABLE=1
 # Enable the AIVector core to directly schedule ROCE communication.
 export HCCL_OP_EXPANSION_MODE="AIV"
-# Enable MLP prefetch for better performance.
-export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1
 # Enable FlashComm_v1 optimization when tensor parallel is enabled.
 export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
 
@@ -94,7 +92,7 @@ vllm serve /data/Qwen3-32B \
     --max-num-batched-tokens 40960 \
     --speculative-config '{"method": "suffix", "num_speculative_tokens": 3}' \
     --gpu-memory-utilization 0.9 \
-    --additional-config '{"pa_shape_list":[48,64,72,80]}' \
+    --additional-config '{"pa_shape_list":[48,64,72,80], "weight_prefetch_config":{"enable":true}}' \
     --port 8011
 ```
 
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import warnings
 from typing import TYPE_CHECKING
 
 from vllm.logger import logger
@@ -48,9 +47,11 @@ class AscendConfig:
         eplb_config = additional_config.get("eplb_config", {})
         self.eplb_config = EplbConfig(eplb_config)
 
+        weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
+        self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
+
         # Dump / PrecisionDebugger configuration
         self.dump_config_path = additional_config.get("dump_config_path", None)
-        self._construct_weight_prefetch_config(additional_config)
         self.layer_sharding = additional_config.get("layer_sharding", None)
         if self.layer_sharding:
             logger.info_once(
@@ -158,29 +159,6 @@ class AscendConfig:
             and get_ascend_device_type() != AscendDeviceType.A5
         )
 
-    def _construct_weight_prefetch_config(self, additional_config):
-        weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
-        self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
-        # Deprecated env var handling for backward compatibility
-        if os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0") == "1":
-            MAX_PREFETCH_WEIGHT_SIZE: int = 18 * 1024 * 1024
-            gate_up_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
-            down_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
-            self.weight_prefetch_config.set_mlp_pre_version_compatibale_config(
-                gate_up_prefetch_size, down_prefetch_size
-            )
-            logger.info_once(
-                f"MLP weight prefetch enabled from env variable VLLM_ASCEND_ENABLE_PREFETCH_MLP."
-                f"gate_up_prefetch_size={gate_up_prefetch_size}, "
-                f"down_prefetch_size={down_prefetch_size}."
-            )
-            warnings.warn(
-                "VLLM_ASCEND_ENABLE_PREFETCH_MLP is deprecated and will be removed in a v0.16.0 version. "
-                "Please use weight_prefetch_config in additional-config for now instead.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
-
     @staticmethod
     def _get_compile_ranges(compilation_config):
         from vllm_ascend.utils import vllm_version_is
@@ -380,28 +358,19 @@ class WeightPrefetchConfig:
     Configuration Object for weight_prefetch_config from additional_config
     """
 
-    mlp_pre_version_compatibale_config: dict = {}
-
     prefetch_ratio: dict = {
         "attn": {
             "qkv": 1.0,
             "o": 1.0,
         },
         "moe": {"gate_up": 0.8},
-        "mlp": {"gate_up": 1, "down": 1.0},
+        "mlp": {"gate_up": 1.0, "down": 1.0},
     }
 
     def __init__(self, weight_prefetch_config: dict):
         self.enabled = weight_prefetch_config.get("enabled", False)
         self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio)
 
-    def set_mlp_pre_version_compatibale_config(self, gate_up_prefetch_size: int, down_prefetch_size: int):
-        config = {
-            "gate_up": gate_up_prefetch_size,
-            "down": down_prefetch_size,
-        }
-        self.mlp_pre_version_compatibale_config = config
-
 
 class EplbConfig:
     """
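One behavior of the trimmed-down `WeightPrefetchConfig` worth calling out: a user-supplied `prefetch_ratio` replaces the class-level default wholesale rather than merging key by key. A self-contained sketch (class and field names mirror the hunk above; the example values are made up):

```python
# Mirrors WeightPrefetchConfig's two-field parsing from the hunk above.
_DEFAULT_RATIO = {
    "attn": {"qkv": 1.0, "o": 1.0},
    "moe": {"gate_up": 0.8},
    "mlp": {"gate_up": 1.0, "down": 1.0},
}

class WeightPrefetchConfigSketch:
    def __init__(self, weight_prefetch_config: dict):
        self.enabled = weight_prefetch_config.get("enabled", False)
        # .get() swaps in the user dict as-is; missing sub-keys are NOT filled in.
        self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", _DEFAULT_RATIO)

cfg = WeightPrefetchConfigSketch({"enabled": True, "prefetch_ratio": {"mlp": {"gate_up": 0.5}}})
print(cfg.prefetch_ratio.get("mlp", {}).get("down", 0))  # 0 -- "down" was dropped by the override
```

Downstream, `maybe_prefetch_mlp_weight_preprocess` treats a missing ratio as 0 (see the hunks below), so a partial override silently disables prefetch for the omitted projection.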
@@ -77,16 +77,6 @@ env_variables: dict[str, Callable[[], Any]] = {
     # For a detailed introduction to the parameters and the differences and applicable scenarios
     # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
     "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
-    # Whether to enable MLP weight prefetch, only used in small concurrency.
-    "VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0"))),
-    # buffer size for gate up prefetch
-    "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": lambda: int(
-        os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)
-    ),
-    # buffer size for down proj prefetch
-    "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": lambda: int(
-        os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)
-    ),
     # Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
     "MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", "0"))),
     # Whether to enable MLAPO optimization for DeepSeek W8A8 series models.
@@ -43,9 +43,6 @@ class WeightPrefetchMethod:
     MLP_GATE_UP: str = "gate_up"
     MLP_DOWN: str = "down"
 
-    # backward compatibility: delete in future versions
-    mlp_pre_version_compatibale_config: dict = {}
-
     def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
         self.is_moe = is_moe_model(get_current_vllm_config())
         self.mla_sfa_prefetch_enable = weight_prefetch_config.enabled
@@ -70,7 +67,6 @@ class WeightPrefetchMethod:
             enable=weight_prefetch_config.enabled and not self.is_moe,
             prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("mlp", {}) or {"gate_up": 1.0, "down": 1.0},
         )
-        self.mlp_pre_version_compatibale_config = weight_prefetch_config.mlp_pre_version_compatibale_config
 
     def maybe_prefetch_attn_weight_preprocess(
         self, layer_cls_name: str, weight: torch.Tensor, start_flag: torch.Tensor
@@ -114,7 +110,7 @@ class WeightPrefetchMethod:
     def maybe_prefetch_mlp_weight_preprocess(
         self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None
     ):
-        if not self.mlp.enable and not self.mlp_pre_version_compatibale_config:
+        if not self.mlp.enable:
             self.mlp.is_active_this_forward = False
             return
 
@@ -146,12 +142,9 @@ class WeightPrefetchMethod:
         model_instance = _EXTRA_CTX.model_instance
         layer_idx = int(curr_layer_prefix.split(".")[2])
         weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight  # type: ignore
-        if self.mlp_pre_version_compatibale_config:
-            weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0)
-        else:
-            weight_size = (
-                weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0)
-            )
+        weight_size = (
+            weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0)
+        )
         if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
             weight_size = MAX_PREFETCH_WEIGHT_SIZE
         torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
@@ -161,12 +154,7 @@ class WeightPrefetchMethod:
         layer_idx = _EXTRA_CTX.layer_idx
         model_instance = _EXTRA_CTX.model_instance
         weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight  # type: ignore
-        if self.mlp_pre_version_compatibale_config:
-            weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0)
-        else:
-            weight_size = (
-                weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0)
-            )
+        weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0)
         if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
             weight_size = MAX_PREFETCH_WEIGHT_SIZE
         torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
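With the fixed byte counts from `VLLM_ASCEND_MLP_*_PREFETCH_SIZE` gone, both MLP paths size the prefetch the same way: a ratio of the weight's byte size, clamped to `MAX_PREFETCH_WEIGHT_SIZE`. A sketch of that arithmetic; the 18 MiB cap is an assumption carried over from the removed env-var default, and the tensor shape is illustrative only:

```python
import torch

# Assumed cap: the removed env handling defaulted to 18 MiB, and the
# MAX_PREFETCH_WEIGHT_SIZE kept by this diff is presumed to match it.
MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024

def mlp_prefetch_bytes(weight: torch.Tensor, ratio: float) -> int:
    """Ratio of the weight's byte size, clamped to the prefetch cap."""
    weight_size = weight.element_size() * weight.numel() * ratio
    return int(min(weight_size, MAX_PREFETCH_WEIGHT_SIZE))

# A bf16 (4096, 11008) gate_up projection occupies ~86 MiB, so even a
# ratio of 1.0 clamps down to the 18 MiB cap.
w = torch.empty(4096, 11008, dtype=torch.bfloat16)
print(mlp_prefetch_bytes(w, 1.0))  # 18874368
```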