[Feature] Support kv nz feature for DeepSeek decode node in disagg-prefill scenario (#3072)

By converting the KV cache from ND to NZ format when the decode node
receives it, this PR ensures that the KV NZ feature works correctly
during the decoding phase in disagg-prefill scenario.

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: ghphotoframe <854746559@qq.com>
Co-authored-by: alex101-ops <alex1015718386@gmail.com>
This commit is contained in:
Jade Zheng
2025-12-31 14:24:04 +08:00
committed by GitHub
parent a539ae753a
commit 38570cfeb6
8 changed files with 163 additions and 95 deletions

View File

@@ -1,5 +1,6 @@
import gc
import pytest
import torch
import torch_npu
@@ -8,8 +9,9 @@ from vllm_ascend.utils import enable_custom_op
enable_custom_op()
@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"])
@torch.inference_mode()
def test_mla_preprocess_kernel():
def test_mla_preprocess_kernel(cache_mode: str):
token_num = 1
head_num = 2
N_7168 = 7168
@@ -98,7 +100,7 @@ def test_mla_preprocess_kernel():
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="krope_ctkv",
cache_mode=cache_mode,
quant_mode="per_tensor_quant_asymm",
enable_inner_out=False,
q_out0=q_nope_out,

View File

@@ -1,5 +1,6 @@
import gc
import pytest
import torch
import torch_npu
@@ -8,8 +9,9 @@ from vllm_ascend.utils import enable_custom_op
enable_custom_op()
@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"])
@torch.inference_mode()
def test_mla_preprocess_kernel():
def test_mla_preprocess_kernel(cache_mode: str):
token_num = 1
head_num = 2
N_7168 = 7168
@@ -82,7 +84,7 @@ def test_mla_preprocess_kernel():
None,
None,
None,
cache_mode="krope_ctkv",
cache_mode=cache_mode,
quant_mode="no_quant",
enable_inner_out=False,
q_out0=q_nope_out,

View File

@@ -1,5 +1,6 @@
import gc
import pytest
import torch
import torch_npu
@@ -8,8 +9,9 @@ from vllm_ascend.utils import enable_custom_op
enable_custom_op()
@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"])
@torch.inference_mode()
def test_mla_preprocess_kernel():
def test_mla_preprocess_kernel(cache_mode: str):
token_num = 1
head_num = 2
N_7168 = 7168
@@ -99,7 +101,7 @@ def test_mla_preprocess_kernel():
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="krope_ctkv",
cache_mode=cache_mode,
quant_mode="per_tensor_quant_asymm",
enable_inner_out=True,
q_out0=q_nope_out,