[Feature] Support kv nz feature for DeepSeek decode node in disagg-prefill scenario (#3072)

By converting the KV cache from ND to NZ format when the decode node
receives it, this PR ensures that the KV NZ feature works correctly
during the decoding phase in the disagg-prefill scenario.
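
For context, NZ is the Ascend fractal layout in which the hidden dimension is
split into 16-element tiles and the token axis moves inside each tile group;
the _nz_kv_cache helper added in this diff views each cache block as
[hidden // 16, block_size, 16]. Below is a minimal CPU sketch of that
re-view; the concrete sizes (MLA-style single KV head, 576-dim head) are
illustrative assumptions, not values taken from this PR.

import torch

# One KV cache block in ND layout: [tokens, hidden].
block_size, num_kv_heads, head_dim = 128, 1, 576  # assumed toy sizes
nz_fmt_last_dim = 16  # width of one NZ fractal tile

nd_block = torch.randn(block_size, num_kv_heads * head_dim)

# NZ groups the hidden dimension into 16-element tiles and makes the token
# axis the second-to-last axis: (hidden // 16, block_size, 16).
nz_block = (nd_block.view(block_size,
                          num_kv_heads * head_dim // nz_fmt_last_dim,
                          nz_fmt_last_dim).transpose(0, 1).contiguous())
assert nz_block.shape == (36, block_size, nz_fmt_last_dim)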

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: ghphotoframe <854746559@qq.com>
Co-authored-by: alex101-ops <alex1015718386@gmail.com>
Jade Zheng authored 2025-12-31 14:24:04 +08:00, committed by GitHub
parent a539ae753a
commit 38570cfeb6
8 changed files with 163 additions and 95 deletions


@@ -412,8 +412,10 @@ class KVCacheRecvingThread(threading.Thread):
             logger.debug(
                 f"Finished transferring KV cache for request {request_id}.")
         except Exception as e:
-            logger.error("Failed to transfer KV cache for request "
-                         f"{request_id}: {e}")
+            logger.error(
+                "Failed to transfer KV cache for request "
+                f"{request_id}: {e}",
+                exc_info=True)
         finally:
             # Always send the done signal to the remote host to ensure proper
             # resource cleanup. Failing to do so may cause a memory leak on the
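
The only functional change in the hunk above is exc_info=True, which tells
the standard logging module to append the exception traceback to the record
instead of only the formatted message. A minimal standalone illustration
(the request id is made up):

import logging

logging.basicConfig()
logger = logging.getLogger(__name__)

try:
    raise RuntimeError("connection reset")  # stand-in for a transfer failure
except Exception as e:
    # With exc_info=True the log record carries the full traceback.
    logger.error("Failed to transfer KV cache for request %s: %s",
                 "req-0", e, exc_info=True)
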
@@ -539,97 +541,116 @@ class KVCacheRecvingThread(threading.Thread):
             request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
             get_ip(), self.tp_rank, session_id)
-        # Determine if the current position is the offset position at the end of the KV transmission.
+        # Determine if the current position is the offset position at the end of
+        # the KV transmission.
         is_kv_transfer_end = (
             global_offset == tp_num_need_pulls * self._prefill_pp_size - 1)
         need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end
-        if need_cat_cache:
-            self._cat_kv_cache(grouped_local_block_ids, tp_num_need_pulls)
+        # NOTE: need_nz_cache may cause errors in non-MLA models.
+        need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end
+        if need_nz_cache or need_cat_cache:
+            self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls,
+                                   need_cat_cache, need_nz_cache)
 
-    def _cat_kv_cache(self, block_ids: list[list[int]],
-                      tp_num_need_pulls: int):
+    def reformat_kv_cache(self,
+                          block_ids: list[list[int]],
+                          tp_num_need_pulls: int,
+                          need_cat_cache: bool = False,
+                          need_nz_cache: bool = False):
         # Get necessary parameters
         k_cache = list(self.kv_caches.values())[0][0]
         dtype = k_cache.dtype
         device = k_cache.device
-        head_dim = self.model_config.hf_text_config.head_dim
-        block_size = self.vllm_config.cache_config.block_size
-        num_kv_head = max(
-            self.model_config.hf_text_config.num_key_value_heads //
-            self.tp_size, 1)
         flat_block_ids = [item for sublist in block_ids for item in sublist]
-        block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32)
+        block_ids_tensor = torch.tensor(flat_block_ids,
+                                        dtype=torch.int32,
+                                        device=device)
         num_blocks = len(flat_block_ids)
-        block_len = num_blocks * block_size
+        num_tokens = num_blocks * self.block_size
 
         # Create device tensors for copy operations
-        block_table = block_ids_tensor.view(1, -1).to(device=device)
-        block_len_tensor = torch.tensor([block_len],
-                                        dtype=torch.int32).to(device=device)
-        seq_start_tensor = torch.tensor([0],
-                                        dtype=torch.int32).to(device=device)
+        block_table = block_ids_tensor.view(1, -1)
+        block_len_tensor = torch.tensor([num_tokens],
+                                        dtype=torch.int32,
+                                        device=device)
+        seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device)
 
         # Initialize buffers
-        k_buffer = torch.empty(block_len,
-                               num_kv_head,
-                               head_dim,
-                               dtype=dtype,
-                               device=device)
-        v_buffer = torch.empty(block_len,
-                               num_kv_head,
-                               head_dim,
-                               dtype=dtype,
-                               device=device)
+        k_buffer = torch.empty(
+            (num_tokens, self.num_kv_heads, self.k_head_dim),
+            dtype=dtype,
+            device=device)
+        v_buffer = torch.empty(
+            (num_tokens, self.num_kv_heads, self.v_head_dim),
+            dtype=dtype,
+            device=device)
 
         # Create slot mapping for reshape operations
-        block_offsets = torch.arange(0, block_size, dtype=torch.int32)
+        block_offsets = torch.arange(0,
+                                     self.block_size,
+                                     dtype=torch.int32,
+                                     device=device)
         slot_mapping = (block_offsets.reshape(
-            (1, block_size)) + block_ids_tensor.reshape(
-                (num_blocks, 1)) * block_size)
-        slot_mapping = slot_mapping.flatten().to(device=device)
+            (1, self.block_size)) + block_ids_tensor.reshape(
+                (num_blocks, 1)) * self.block_size).flatten()
+
+        # FIXME: Right now, if we skip synchronization at this point, the system
+        # will crash in GQA scenarios. However, we still haven't identified the
+        # root cause.
+        torch.npu.synchronize()
 
         # Process each layer in the KV cache
         for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
             # Load cache data into buffers
-            torch_npu.atb.npu_paged_cache_load(
-                k_cache_layer,
-                v_cache_layer,
-                block_table,
-                block_len_tensor,
-                seq_starts=seq_start_tensor,
-                key=k_buffer,
-                value=v_buffer,
-            )
-            # Transpose KV cache
-            k_buffer = self._transpose_kv_cache_between_head(
-                k_buffer, num_blocks, block_size, block_len, num_kv_head,
-                tp_num_need_pulls)
-            v_buffer = self._transpose_kv_cache_between_head(
-                v_buffer, num_blocks, block_size, block_len, num_kv_head,
-                tp_num_need_pulls)
-            # Reshape and cache the processed buffers
-            torch_npu._npu_reshape_and_cache(
-                key=k_buffer,
-                value=v_buffer,
-                key_cache=k_cache_layer,
-                value_cache=v_cache_layer,
-                slot_indices=slot_mapping,
-            )
+            torch_npu.atb.npu_paged_cache_load(k_cache_layer,
+                                               v_cache_layer,
+                                               block_table,
+                                               block_len_tensor,
+                                               seq_starts=seq_start_tensor,
+                                               key=k_buffer,
+                                               value=v_buffer)
+            if need_cat_cache:
+                self._cat_kv_cache(k_cache_layer, v_cache_layer, k_buffer,
+                                   v_buffer, tp_num_need_pulls, num_blocks,
+                                   num_tokens, slot_mapping)
+            if need_nz_cache:
+                self._nz_kv_cache(k_cache_layer, v_cache_layer, k_buffer,
+                                  v_buffer, slot_mapping)
 
         # Clean up buffers
         del k_buffer, v_buffer
 
-    def _transpose_kv_cache_between_head(
-            self, buffer: torch.Tensor, num_blocks: int, block_size: int,
-            block_len: int, num_kv_head: int,
-            tp_num_need_pulls: int) -> torch.Tensor:
-        buffer = buffer.view(num_blocks, tp_num_need_pulls, block_size, -1)
-        buffer.transpose_(1, 2)
-        return buffer.contiguous().view(block_len, num_kv_head, -1)
+    def _cat_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer,
+                      tp_num_need_pulls, num_blocks, num_tokens, slot_mapping):
+
+        def _transpose_kv_cache_between_head(
+                buffer: torch.Tensor) -> torch.Tensor:
+            buffer = buffer.view(num_blocks, tp_num_need_pulls,
+                                 self.block_size, -1)
+            buffer.transpose_(1, 2)
+            return buffer.contiguous().view(num_tokens, self.num_kv_heads, -1)
+
+        # Transpose KV cache
+        k_buffer = _transpose_kv_cache_between_head(k_buffer)
+        v_buffer = _transpose_kv_cache_between_head(v_buffer)
+
+        # Reshape and cache the processed buffers
+        torch_npu._npu_reshape_and_cache(key=k_buffer,
+                                         value=v_buffer,
+                                         key_cache=k_cache_layer,
+                                         value_cache=v_cache_layer,
+                                         slot_indices=slot_mapping)
+
+    def _nz_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer,
+                     slot_mapping):
+        nz_fmt_last_dim = 16
+        k_cache_layer = k_cache_layer.view(
+            -1, self.k_head_dim * self.num_kv_heads // nz_fmt_last_dim,
+            self.block_size, nz_fmt_last_dim)
+        v_cache_layer = v_cache_layer.view(
+            -1, self.v_head_dim * self.num_kv_heads // nz_fmt_last_dim,
+            self.block_size, nz_fmt_last_dim)
+        torch_npu.npu_scatter_pa_kv_cache(k_buffer, v_buffer, k_cache_layer,
+                                          v_cache_layer, slot_mapping)
 
     def _get_remote_metadata(self, remote_host: str,
                              remote_handshake_port: int) -> None:
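
For reference, the slot_mapping built in reformat_kv_cache assigns token j of
cache block b the flat slot b * block_size + j, so non-contiguous blocks can
be written back in a single scatter. A toy CPU reproduction of that
arithmetic (the block ids and block size are made up):

import torch

block_size = 4
block_ids_tensor = torch.tensor([7, 2], dtype=torch.int32)  # two toy blocks
num_blocks = block_ids_tensor.numel()

block_offsets = torch.arange(0, block_size, dtype=torch.int32)
# Same formula as the diff: offset within block + block id * block size.
slot_mapping = (block_offsets.reshape((1, block_size)) +
                block_ids_tensor.reshape(
                    (num_blocks, 1)) * block_size).flatten()
print(slot_mapping.tolist())  # [28, 29, 30, 31, 8, 9, 10, 11]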
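
Similarly, when tp_num_need_pulls > 1 each decode rank has pulled its KV
heads in slices from several prefill ranks, so within a block the slices are
stacked along an extra pull axis; the inner _transpose_kv_cache_between_head
swaps that axis with the token axis so all heads become contiguous per token
before _npu_reshape_and_cache writes them back. A shape-only sketch of that
reshuffle with assumed toy sizes:

import torch

num_blocks, tp_num_need_pulls, block_size = 2, 2, 4
num_kv_heads, head_dim = 4, 8  # toy GQA-style sizes, not from the PR
num_tokens = num_blocks * block_size

buffer = torch.randn(num_tokens, num_kv_heads, head_dim)
# [block, pull, token, heads_per_pull * head_dim] -> swap pull and token.
buffer = buffer.view(num_blocks, tp_num_need_pulls, block_size, -1)
buffer.transpose_(1, 2)
out = buffer.contiguous().view(num_tokens, num_kv_heads, -1)
assert out.shape == (num_tokens, num_kv_heads, head_dim)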