[Kernel] Add AscendC fused op transpose_kv_cache_by_block to speed up GQA transfer (#6366)

### What this PR does / why we need it?
As #2947 describes, we need to transpose the kv cache layout after the GQA kv
transfer when the prefill and decode tensor parallel sizes are heterogeneous.
In the previous implementation, we used `npu_paged_cache_load` +
`transpose` + `_npu_reshape_and_cache` to do this work.

This is not an efficient plan: the ops above have to be called for every
layer, which introduces 3 * layer_num kernel launches and 6 * layer_num data
movements between L1 cache and HBM for a single request on the decode node.
The decode node usually runs in graph mode, so these op kernels, launched by
an async thread in the mooncake connector, are interleaved between decode
forward passes; they can last for several decode forward passes, and TTFT
increases by roughly 3~4 decode forward times.

In this PR, we implement an AscendC fused op
`transpose_kv_cache_by_block` that does this work with a single kernel launch
and a single pass of data movement between L1 cache and HBM.
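
For illustration, here is a rough sketch of how the decode-side connector hands all layers to the fused op in one call. It mirrors the `reformat_kv_cache_with_fused_op` helper added in this PR; the shapes and config fields come from the diff below, while the standalone function signature is only for the sketch:

```python
import torch

def transpose_kv_cache_fused(kv_caches, block_id_groups, block_size, num_kv_head,
                             head_dim, tp_num_need_pulls, num_layers, device):
    # kv_caches: dict of layer name -> (k_cache, v_cache), as kept by the connector.
    k_caches = [k for k, _ in kv_caches.values()]
    v_caches = [v for _, v in kv_caches.values()]
    # Flatten the per-pull groups of block ids into a single index tensor.
    flat_ids = [b for group in block_id_groups for b in group]
    block_ids = torch.tensor(flat_ids, dtype=torch.int64, device=device)
    # One kernel launch reformats every listed block across all layers, instead of
    # 3 * layer_num launches with the old per-layer op sequence.
    torch.ops._C_ascend.transpose_kv_cache_by_block(
        k_caches, v_caches, block_ids, block_size,
        num_kv_head, head_dim, tp_num_need_pulls, num_layers)
```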

With this fused op, the time spent transposing the kv cache layout drops
from 7 ms to 0.24 ms in the UT on 910C, and in the PD disaggregation
scenario TTFT decreases by about 90 ~ 110 ms for qwen3-235B.

| request_num | original | fused_op|
|:----------------------:|:---------------:|:-------------------:|
|           1            |      643 ms      |        578 ms        |
|          128           |     1480 ms      |       1368 ms        |

### Does this PR introduce _any_ user-facing change?
The fused op is used by default. In case the op has a bug in some scenario, a
fallback is provided: it can be disabled via an env variable.

**To DISABLE the fused op, set the following env variable:**
`export VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK=0`

### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main: dc917cceb8

---------

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
15 changed files with 913 additions and 3 deletions


```diff
@@ -46,10 +46,11 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.request import RequestStatus
 from vllm_ascend import envs as ascend_envs
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te
 from vllm_ascend.distributed.kv_transfer.utils.utils import get_transfer_timeout_value
-from vllm_ascend.utils import is_vl_model
+from vllm_ascend.utils import enable_custom_op, is_vl_model
 # isort: off
 if TYPE_CHECKING:
@@ -570,8 +571,39 @@ class KVCacheRecvingThread(threading.Thread):
         is_kv_transfer_end = global_offset == tp_num_need_pulls * self._prefill_pp_size - 1
         need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end
         need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end
+        use_fused_op = ascend_envs.VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK
         if need_nz_cache or need_cat_cache:
-            self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache)
+            # Use the fused op to reformat the kv cache; the original implementation is kept
+            # so the fused path can be disabled.
+            if use_fused_op and enable_custom_op():
+                if need_cat_cache:
+                    # The fused op only supports concatenating GQA/MHA kv cache by head.
+                    self.reformat_kv_cache_with_fused_op(grouped_local_block_ids, tp_num_need_pulls)
+                if need_nz_cache:
+                    # The kv NZ reformat may also move to a fused op in the future.
+                    self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, False, need_nz_cache)
+            else:
+                self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache)
+
+    def reformat_kv_cache_with_fused_op(self, block_ids: list[list[int]], tp_num_need_pulls: int):
+        # Get the necessary parameters.
+        k_cache = list(self.kv_caches.values())[0][0]
+        device = k_cache.device
+        head_dim = self.model_config.hf_text_config.head_dim
+        block_size = self.vllm_config.cache_config.block_size
+        num_kv_head = max(self.model_config.hf_text_config.num_key_value_heads // self.tp_size, 1)
+        layers = self.model_config.hf_text_config.num_hidden_layers
+        flat_block_ids = [item for sublist in block_ids for item in sublist]
+        block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int64, device=device)
+        k_caches = []
+        v_caches = []
+        for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
+            k_caches.append(k_cache_layer)
+            v_caches.append(v_cache_layer)
+        torch.ops._C_ascend.transpose_kv_cache_by_block(
+            k_caches, v_caches, block_ids_tensor, block_size, num_kv_head, head_dim, tp_num_need_pulls, layers
+        )
+
     def reformat_kv_cache(
         self,
```
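
For intuition only, here is a rough torch-level sketch of the per-block reshuffle the fused kernel performs for one cache tensor. The incoming layout (head chunks pulled from each source rank stored back to back inside a block, to be interleaved per token slot) is an assumption inferred from the description above, not the authoritative semantics of the AscendC kernel:

```python
import torch

def reference_transpose_block(block: torch.Tensor, tp_num_need_pulls: int,
                              num_kv_head: int, head_dim: int) -> torch.Tensor:
    # ASSUMED incoming layout: (tp_num_need_pulls, block_size, heads_per_pull, head_dim),
    # i.e. the chunk pulled from each source rank occupies a contiguous slice of the block.
    block_size = block.numel() // (num_kv_head * head_dim)
    heads_per_pull = num_kv_head // tp_num_need_pulls
    chunked = block.reshape(tp_num_need_pulls, block_size, heads_per_pull, head_dim)
    # Target layout: (block_size, num_kv_head, head_dim), so every token slot sees the
    # heads from all pulls concatenated along the head dimension.
    return chunked.permute(1, 0, 2, 3).reshape(block_size, num_kv_head, head_dim)
```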