### What this PR does / why we need it?
`VLLM_ASCEND_ENABLE_NZ` now accepts three values:
- `0`: disable NZ;
- `1`: enable NZ only for quantized weights;
- `2`: enable NZ wherever possible.

`VLLM_ASCEND_ENABLE_NZ` defaults to `1`.
All cases are shown in the table below:

| | W4A4 | W4A8 | W8A8 | fp16/bf16 | fp32 |
|---|---|---|---|---|---|
| NZ conversion | NZ not supported | converted to NZ by default | converted to NZ by default | converted to NZ only when `VLLM_ASCEND_ENABLE_NZ` is 2 | NZ not supported |
| Transpose | only the non-transposed case is supported | only the transposed case is supported | only the transposed case is supported | linear: only the non-transposed case is supported<br>gmm: only the transposed case is supported | same as fp16/bf16 |
Some exceptional cases:
1. The MLAPO op needs additional preprocessing of its weights, including the NZ conversion. When the MLAPO op is used, some weights are therefore converted to NZ unconditionally.
2. The MLA/SFA weight `W_UV` is consumed by the op `torch.ops._C_ascend.batch_matmul_transpose`, which does not currently support NZ.
### Does this PR introduce _any_ user-facing change?
fp16/bf16 weights are no longer converted to NZ by default.
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: zzzzwwjj <1183291235@qq.com>
155 lines
7.7 KiB
Python
155 lines
7.7 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# This file is mainly Adapted from vllm-project/vllm/vllm/envs.py
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import os
|
|
from typing import Any, Callable, Dict
|
|
|
|
# The begin-* and end* here are used by the documentation generator
|
|
# to extract the used env vars.
|
|
|
|
# begin-env-vars-definition
|
|
|
|
env_variables: Dict[str, Callable[[], Any]] = {
    # Each value is a zero-argument getter that reads the environment at call
    # time; nothing is cached here.

    # ---- Build-time settings ----

    # Maximum number of parallel compile jobs for package building. Usually
    # set to the number of CPU cores. If unset, None is returned, which means
    # all CPU cores will be used.
    "MAX_JOBS":
    lambda: os.getenv("MAX_JOBS", None),
    # CMake build type of the package: one of Release, Debug or
    # RelWithDebInfo (CMake's spelling). If unset, None is returned here; the
    # build is expected to fall back to Release (that fallback is not applied
    # by this getter).
    "CMAKE_BUILD_TYPE":
    lambda: os.getenv("CMAKE_BUILD_TYPE"),
    # CXX compiler used to compile the package. If unset, None is returned,
    # which means the system default CXX compiler will be used.
    "CXX_COMPILER":
    lambda: os.getenv("CXX_COMPILER", None),
    # C compiler used to compile the package. If unset, None is returned,
    # which means the system default C compiler will be used.
    "C_COMPILER":
    lambda: os.getenv("C_COMPILER", None),
    # Version of the Ascend chip, used for package building. If unset, chip
    # info is queried through `npu-smi`. Please make sure the value is correct.
    "SOC_VERSION":
    lambda: os.getenv("SOC_VERSION", None),
    # If set to a non-zero integer, vllm-ascend prints verbose logs during
    # compilation.
    "VERBOSE":
    lambda: bool(int(os.getenv('VERBOSE', '0'))),
    # Home path of the CANN toolkit. If unset, None is returned and
    # /usr/local/Ascend/ascend-toolkit/latest is expected to be used.
    "ASCEND_HOME_PATH":
    lambda: os.getenv("ASCEND_HOME_PATH", None),
    # Path of the HCCL library, used by the pyhccl communicator backend. If
    # unset, None is returned and libhccl.so is expected to be used.
    "HCCL_SO_PATH":
    lambda: os.getenv("HCCL_SO_PATH", None),
    # Version of vLLM to build against. Intended for developers who installed
    # vLLM from source: such an install usually reports a different version
    # (e.g. "0.9.1" although the released version is "0.9.0"). Set this to the
    # released version ("0.9.0" in that example) so the correct package is
    # installed.
    "VLLM_VERSION":
    lambda: os.getenv("VLLM_VERSION", None),

    # ---- Runtime feature switches ----

    # Whether to enable fused_experts_allgather_ep, which combines the
    # MoeInitRoutingV3 and GroupedMatmulFinalizeRouting operators to implement
    # EP.
    "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
    lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
                 ),
    # Whether to enable the model-execute-time observation profile. Disable it
    # when running vllm-ascend in production.
    "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
    lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
                 ),
    # Some models are optimized by vllm-ascend. In some cases (e.g. RLHF
    # training) the optimized model may not be suitable; set this to 0 to
    # disable the optimized models.
    "USE_OPTIMIZED_MODEL":
    lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
    # Whether to enable the MatmulAllReduce fusion kernel when tensor
    # parallelism is enabled. Supported on A2; eager mode gets better
    # performance.
    "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
    # Whether to enable the FlashComm optimization when tensor parallelism is
    # enabled. It performs better under high concurrency.
    "VLLM_ASCEND_ENABLE_FLASHCOMM1":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
    # FLASHCOMM2 switch and size: 0 disables the feature; a value of 1 or
    # above enables it and is used as the O-matrix TP group size for
    # flashcomm2. See the feature guide in the documentation for parameter
    # details and how it differs from FLASHCOMM1.
    "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
    lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
    # Adds the shared-weight feature on top of
    # VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE, eliminating redundant weight
    # storage (see PR#4188). Recommended when Flashcomm2 is enabled and device
    # memory is limited.
    "VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0"))),
    # Whether to enable MLP weight prefetch; only used at low concurrency.
    "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
    # Buffer size (bytes) for gate/up projection prefetch; default 18 MiB.
    "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE":
    lambda: int(
        os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
    # Buffer size (bytes) for down projection prefetch; default 18 MiB.
    "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE":
    lambda: int(
        os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
    # Whether to enable dense-model and general optimizations. Because the
    # base `linear` class is modified, this also affects other model types;
    # there may be hidden issues, so it is currently recommended mainly for
    # dense models.
    "VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))),
    # Whether to enable the msMonitor tool to monitor vllm-ascend performance.
    "MSMONITOR_USE_DAEMON":
    lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))),
    # Whether to enable the MLAPO op. MLAPO preprocesses some weights itself
    # (including the FRACTAL_NZ conversion), so those weights become NZ even
    # when VLLM_ASCEND_ENABLE_NZ is 0.
    "VLLM_ASCEND_ENABLE_MLAPO":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
    # Whether to cast weights to the FRACTAL_NZ format (returned as an int):
    #   0: disable NZ;
    #   1: enable NZ only for quantized weights (default);
    #   2: enable NZ wherever it is supported.
    # Values outside 0-2 are not validated here.
    "VLLM_ASCEND_ENABLE_NZ":
    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", '1')),
    # Whether to enable context parallelism (CP).
    "VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", '0'))),
    # Whether to enable dynamic EPLB. Returned as a lower-cased string
    # (default "false"), not as a bool.
    "DYNAMIC_EPLB":
    lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
    # Whether to enable fused MC2 (dispatch_gmm_combine_decode /
    # dispatch_ffn_combine operators). Returned as an int (default 0).
    "VLLM_ASCEND_ENABLE_FUSED_MC2":
    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
}
|
|
|
|
# end-env-vars-definition
|
|
|
|
|
|
def __getattr__(name: str):
    """Resolve ``<module>.NAME`` lazily through :data:`env_variables`.

    Module-level ``__getattr__`` is only consulted when normal attribute
    lookup on the module fails (PEP 562). Every access calls the registered
    getter again, so the current environment value is returned and nothing is
    cached here.

    Raises:
        AttributeError: if ``name`` is not a registered environment variable.
    """
    getter = env_variables.get(name)
    if getter is None:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}")
    return getter()
|
|
|
|
|
|
def __dir__():
    """List the registered environment-variable names for ``dir(module)``.

    Only the keys of :data:`env_variables` are returned, in insertion order.
    """
    return [*env_variables]
|