### What this PR does / why we need it?
As #2947 describe, we need to transpose kv cache layout after GQA kv
transfer when prefill and decode tensor parallel size are heterogeneous,
in the previous implementation, we use `npu_paged_cache_load ` +
`transpose` + `_npu_reshape_and_cache` to do this work.
But obviously, it is not an efficient plan, the ops above need to be
called for each layer, which introduces 3 * layer_num kernel launch, and
6 * layer_num data movement between L1 Cache and HBM for one request on
decode node. Usually, decode node uses graph mode, so these op kernels
will be called between decode forward steps launched by an async thread in the
mooncake connector; these kernels may last for several decode forward steps,
and TTFT will increase by 3~4 decode forward times.
In this PR, we implement an AscendC fused op
`transpose_kv_cache_by_block` to do this with only once kernel launch
and move data between L1 Cache and HBM only once.
After using this fused op, the time cost of transposing the kv cache layout
can be decreased to 0.24ms from 7ms in UT on 910C, and in PD
disaggregation scenario, TTFT can decrease about 90 ~ 110 ms in
qwen3-235B.
| request_num | original | fused_op|
|:----------------------:|:---------------:|:-------------------:|
| 1 | 643 ms | 578 ms |
| 128 | 1480 ms | 1368 ms |
### Does this PR introduce _any_ user-facing change?
The fused op is used by default; in case the op has a bug in any scenario, a
fallback is provided: use an env var to disable it.
**DISABLE fused op by add following env**
`export VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK=0`
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
132 lines
7.3 KiB
Python
132 lines
7.3 KiB
Python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# This file is mainly Adapted from vllm-project/vllm/vllm/envs.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from collections.abc import Callable
from typing import Any
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.

# begin-env-vars-definition

# Each entry maps an env var name to a zero-argument callable so the value is
# read lazily (at module attribute access, see `__getattr__`) instead of being
# frozen at import time. Defaults passed to `os.getenv` are strings so every
# branch of a lambda converts from the same type.
env_variables: dict[str, Callable[[], Any]] = {
    # max compile thread number for package building. Usually, it is set to
    # the number of CPU cores. If not set, the default value is None, which
    # means all number of CPU cores will be used.
    "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
    # The build type of the package. It can be one of the following values:
    # Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
    "CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"),
    # The CXX compiler used for compiling the package. If not set, the default
    # value is None, which means the system default CXX compiler will be used.
    "CXX_COMPILER": lambda: os.getenv("CXX_COMPILER", None),
    # The C compiler used for compiling the package. If not set, the default
    # value is None, which means the system default C compiler will be used.
    "C_COMPILER": lambda: os.getenv("C_COMPILER", None),
    # The version of the Ascend chip. It's used for package building.
    # If not set, we will query chip info through `npu-smi`.
    # Please make sure that the version is correct.
    "SOC_VERSION": lambda: os.getenv("SOC_VERSION", None),
    # If set, vllm-ascend will print verbose logs during compilation
    "VERBOSE": lambda: bool(int(os.getenv("VERBOSE", "0"))),
    # The home path for CANN toolkit. If not set, the default value is
    # /usr/local/Ascend/ascend-toolkit/latest
    "ASCEND_HOME_PATH": lambda: os.getenv("ASCEND_HOME_PATH", None),
    # The path for HCCL library, it's used by pyhccl communicator backend. If
    # not set, the default value is libhccl.so.
    "HCCL_SO_PATH": lambda: os.environ.get("HCCL_SO_PATH", None),
    # The version of vllm is installed. This value is used for developers who
    # installed vllm from source locally. In this case, the version of vllm is
    # usually changed. For example, if the version of vllm is "0.9.0", but when
    # it's installed from source, the version of vllm is usually set to "0.9.1".
    # In this case, developers need to set this value to "0.9.0" to make sure
    # that the correct package is installed.
    "VLLM_VERSION": lambda: os.getenv("VLLM_VERSION", None),
    # Some models are optimized by vllm ascend. While in some case, e.g. rlhf
    # training, the optimized model may not be suitable. In this case, set this
    # value to False to disable the optimized model.
    "USE_OPTIMIZED_MODEL": lambda: bool(int(os.getenv("USE_OPTIMIZED_MODEL", "1"))),
    # Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled.
    # this feature is supported in A2, and eager mode will get better performance.
    "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", "0"))),
    # Whether to enable FlashComm optimization when tensor parallel is enabled.
    # This feature will get better performance when concurrency is large.
    "VLLM_ASCEND_ENABLE_FLASHCOMM1": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))),
    # Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it.
    # The specific value set will be used as the O-matrix TP group size for flashcomm2.
    # For a detailed introduction to the parameters and the differences and applicable scenarios
    # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
    "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", "0")),
    # Whether to enable MLP weight prefetch, only used in small concurrency.
    "VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0"))),
    # buffer size for gate up prefetch (bytes); default is 18 MiB
    "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": lambda: int(
        os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", str(18 * 1024 * 1024))
    ),
    # buffer size for down proj prefetch (bytes); default is 18 MiB
    "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": lambda: int(
        os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", str(18 * 1024 * 1024))
    ),
    # Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
    "MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", "0"))),
    # Whether to enable MLAPO optimization for DeepSeek W8A8 series models.
    # This option is enabled by default. MLAPO can improve performance, but
    # it will consume more NPU memory. If reducing NPU memory usage is a higher priority
    # for your DeepSeek W8A8 scene, then disable it.
    "VLLM_ASCEND_ENABLE_MLAPO": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", "1"))),
    # Whether to enable weight cast format to FRACTAL_NZ.
    # 0: close nz;
    # 1: only quant case enable nz;
    # 2: enable nz as long as possible.
    "VLLM_ASCEND_ENABLE_NZ": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", "1")),
    # Decide whether we should enable CP parallelism.
    "VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", "0"))),
    # Whether to enable dynamic EPLB.
    # NOTE: unlike the other toggles this yields the lowercased string
    # ("true"/"false"), not a bool — callers compare against the string.
    "DYNAMIC_EPLB": lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
    # Whether to enable fused mc2(`dispatch_gmm_combine_decode`/`dispatch_ffn_combine` operator)
    # 0, or not set: default ALLTOALL and MC2 will be used.
    # 1: ALLTOALL and MC2 might be replaced by `dispatch_ffn_combine` operator.
    # `dispatch_ffn_combine` can be used only for moe layer with W8A8, EP<=32, non-mtp, non-dynamic-eplb.
    # 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator.
    # `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer
    # with W8A8. And MTP layer must be W8A8.
    "VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")),
    # Whether to enable balance scheduling
    "VLLM_ASCEND_BALANCE_SCHEDULING": lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", "0"))),
    # use fused op transpose_kv_cache_by_block, default is True
    "VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK": lambda: bool(
        int(os.getenv("VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK", "1"))
    ),
}

# end-env-vars-definition
|
|
|
|
|
|
def __getattr__(name: str):
    """Resolve module attributes lazily from ``env_variables``.

    Looking up an env var as a module attribute invokes its loader lambda,
    so the environment is consulted at access time rather than import time.
    Raises AttributeError for names that are not declared env vars.
    """
    loader = env_variables.get(name)
    if loader is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return loader()
def __dir__():
    """List the declared env var names for ``dir()`` and tab completion."""
    return [*env_variables]